Skip to content

Commit 203249f

Browse files
committed
add rewrite rule
1 parent acdd5d3 commit 203249f

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from kirin import ir
2+
from kirin.dialects.py import Add, Div, Sub, Mult, BinOp
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
from kirin.ir.attrs.types import Generic, PyClass
5+
from kirin.dialects.ilist.runtime import IList
6+
7+
from ..stmts import add as vadd, div as vdiv, mul as vmul, sub as vsub
8+
from .._dialect import dialect
9+
10+
11+
@dialect.post_inference
12+
class DesugarBinOp(RewriteRule):
13+
14+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
15+
match node:
16+
case BinOp():
17+
match (node.lhs.type, node.rhs.type):
18+
case (PyClass(lhs_typ), Generic(PyClass(rhs_typ))):
19+
if (lhs_typ is float or lhs_typ is int) and rhs_typ == IList:
20+
return self.replace_binop(node)
21+
case (Generic(PyClass(lhs_typ)), PyClass(rhs_typ)):
22+
if lhs_typ is IList and (rhs_typ is float or rhs_typ is int):
23+
return self.replace_binop(node)
24+
case _:
25+
return RewriteResult()
26+
27+
case _:
28+
return RewriteResult()
29+
30+
def replace_binop(self, node: ir.Statement):
31+
match node:
32+
case Add():
33+
node.replace_by(vadd(lhs=node.lhs, rhs=node.rhs))
34+
return RewriteResult(has_done_something=True)
35+
case Sub():
36+
node.replace_by(vsub(lhs=node.lhs, rhs=node.rhs))
37+
return RewriteResult(has_done_something=True)
38+
case Mult():
39+
node.replace_by(vmul(lhs=node.lhs, rhs=node.rhs))
40+
return RewriteResult(has_done_something=True)
41+
case Div():
42+
node.replace_by(vdiv(lhs=node.lhs, rhs=node.rhs))
43+
return RewriteResult(has_done_something=True)
44+
case _:
45+
return RewriteResult()
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# from typing import Any
2+
#
3+
# import numpy as np
4+
#
5+
# from kirin import types
6+
from kirin.prelude import basic
7+
from kirin.dialects import vmath
8+
9+
10+
@basic.union([vmath])
11+
def add_kernel(x, y):
12+
return x + y
13+
14+
15+
@basic.union([vmath])
16+
def add_two_lists():
17+
return add_kernel(x=[0, 1, 2], y=[3, 4, 5])
18+
19+
20+
@basic.union([vmath])(aggressive=True)
21+
def add_scalar_lhs():
22+
return add_kernel(x=3.0, y=[3.0, 4, 5])
23+
24+
25+
def test_add_scalar_lhs():
26+
# out = add_scalar_lhs()
27+
import ipdb
28+
29+
ipdb.set_trace()

0 commit comments

Comments
 (0)