|
3 | 3 | // |
4 | 4 | // Converts from Python classes we define in pure python so we can use dataclasses |
5 | 5 | // to represent the input types |
| 6 | +// TODO: Copy strings of these from egg-smol... Maybe actually wrap those isntead. |
6 | 7 | use pyo3::prelude::*; |
7 | 8 |
|
8 | 9 | // Execute the block and wrap the error in a type error |
@@ -175,3 +176,64 @@ impl From<WrappedLiteral> for egg_smol::ast::Literal { |
175 | 176 | other.0 |
176 | 177 | } |
177 | 178 | } |
| 179 | + |
| 180 | +// Wrapped version of Rewrite |
| 181 | +pub struct WrappedRewrite(egg_smol::ast::Rewrite); |
| 182 | + |
| 183 | +impl FromPyObject<'_> for WrappedRewrite { |
| 184 | + fn extract(obj: &'_ PyAny) -> PyResult<Self> { |
| 185 | + wrap_error("Rewrite", obj, || { |
| 186 | + Ok(WrappedRewrite(egg_smol::ast::Rewrite { |
| 187 | + lhs: obj.getattr("lhs")?.extract::<WrappedExpr>()?.into(), |
| 188 | + rhs: obj.getattr("rhs")?.extract::<WrappedExpr>()?.into(), |
| 189 | + conditions: obj |
| 190 | + .getattr("conditions")? |
| 191 | + .extract::<Vec<WrappedFact>>()? |
| 192 | + .into_iter() |
| 193 | + .map(|x| x.into()) |
| 194 | + .collect(), |
| 195 | + })) |
| 196 | + }) |
| 197 | + } |
| 198 | +} |
| 199 | + |
| 200 | +impl From<WrappedRewrite> for egg_smol::ast::Rewrite { |
| 201 | + fn from(other: WrappedRewrite) -> Self { |
| 202 | + other.0 |
| 203 | + } |
| 204 | +} |
| 205 | + |
| 206 | +// Wrapped version of Fact |
| 207 | +pub struct WrappedFact(egg_smol::ast::Fact); |
| 208 | + |
| 209 | +impl FromPyObject<'_> for WrappedFact { |
| 210 | + fn extract(obj: &'_ PyAny) -> PyResult<Self> { |
| 211 | + wrap_error("Fact", obj, || { |
| 212 | + extract_fact_eq(obj) |
| 213 | + .or_else(|_| extract_fact_fact(obj)) |
| 214 | + .map(WrappedFact) |
| 215 | + }) |
| 216 | + } |
| 217 | +} |
| 218 | + |
| 219 | +fn extract_fact_eq(obj: &PyAny) -> PyResult<egg_smol::ast::Fact> { |
| 220 | + Ok(egg_smol::ast::Fact::Eq( |
| 221 | + obj.getattr("exprs")? |
| 222 | + .extract::<Vec<WrappedExpr>>()? |
| 223 | + .into_iter() |
| 224 | + .map(|x| x.into()) |
| 225 | + .collect(), |
| 226 | + )) |
| 227 | +} |
| 228 | + |
| 229 | +fn extract_fact_fact(obj: &PyAny) -> PyResult<egg_smol::ast::Fact> { |
| 230 | + Ok(egg_smol::ast::Fact::Fact( |
| 231 | + obj.getattr("expr")?.extract::<WrappedExpr>()?.into(), |
| 232 | + )) |
| 233 | +} |
| 234 | + |
| 235 | +impl From<WrappedFact> for egg_smol::ast::Fact { |
| 236 | + fn from(other: WrappedFact) -> Self { |
| 237 | + other.0 |
| 238 | + } |
| 239 | +} |
0 commit comments