|
| 1 | +from kirin import ir |
| 2 | +from kirin.dialects.py import Add, Div, Sub, Mult, BinOp |
| 3 | +from kirin.rewrite.abc import RewriteRule, RewriteResult |
| 4 | +from kirin.ir.attrs.types import Generic, PyClass |
| 5 | +from kirin.dialects.ilist.runtime import IList |
| 6 | + |
| 7 | +from ..stmts import add as vadd, div as vdiv, mul as vmul, sub as vsub |
| 8 | +from .._dialect import dialect |
| 9 | + |
| 10 | + |
| 11 | +@dialect.post_inference |
| 12 | +class DesugarBinOp(RewriteRule): |
| 13 | + |
| 14 | + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
| 15 | + match node: |
| 16 | + case BinOp(): |
| 17 | + match (node.lhs.type, node.rhs.type): |
| 18 | + case (PyClass(lhs_typ), Generic(PyClass(rhs_typ))): |
| 19 | + if (lhs_typ is float or lhs_typ is int) and rhs_typ == IList: |
| 20 | + return self.replace_binop(node) |
| 21 | + case (Generic(PyClass(lhs_typ)), PyClass(rhs_typ)): |
| 22 | + if lhs_typ is IList and (rhs_typ is float or rhs_typ is int): |
| 23 | + return self.replace_binop(node) |
| 24 | + case _: |
| 25 | + return RewriteResult() |
| 26 | + |
| 27 | + case _: |
| 28 | + return RewriteResult() |
| 29 | + |
| 30 | + def replace_binop(self, node: ir.Statement): |
| 31 | + match node: |
| 32 | + case Add(): |
| 33 | + node.replace_by(vadd(lhs=node.lhs, rhs=node.rhs)) |
| 34 | + return RewriteResult(has_done_something=True) |
| 35 | + case Sub(): |
| 36 | + node.replace_by(vsub(lhs=node.lhs, rhs=node.rhs)) |
| 37 | + return RewriteResult(has_done_something=True) |
| 38 | + case Mult(): |
| 39 | + node.replace_by(vmul(lhs=node.lhs, rhs=node.rhs)) |
| 40 | + return RewriteResult(has_done_something=True) |
| 41 | + case Div(): |
| 42 | + node.replace_by(vdiv(lhs=node.lhs, rhs=node.rhs)) |
| 43 | + return RewriteResult(has_done_something=True) |
| 44 | + case _: |
| 45 | + return RewriteResult() |
0 commit comments