Skip to content

Commit 020845a

Browse files
committed
Fix Scale rewrite
1 parent 5884d0b commit 020845a

File tree

3 files changed

+100
-3
lines changed

3 files changed

+100
-3
lines changed

src/bloqade/squin/op/rewrite.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from kirin.dialects import py
77
from kirin.rewrite.abc import RewriteRule, RewriteResult
88

9-
from .stmts import Mult, Scale
9+
from .stmts import Mult, Scale, Operator
1010
from .types import OpType
1111

1212

@@ -17,11 +17,17 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
1717
return RewriteResult()
1818

1919
lhs_is_op = node.lhs.type.is_subseteq(OpType)
20-
rhs_is_op = node.lhs.type.is_subseteq(OpType)
20+
rhs_is_op = node.rhs.type.is_subseteq(OpType)
2121

2222
if not lhs_is_op and not rhs_is_op:
2323
return RewriteResult()
2424

25+
if isinstance(node.lhs, ir.ResultValue):
26+
lhs_is_op = isinstance(node.lhs.stmt, Operator)
27+
28+
if isinstance(node.rhs, ir.ResultValue):
29+
rhs_is_op = isinstance(node.rhs.stmt, Operator)
30+
2531
if lhs_is_op and rhs_is_op:
2632
mult = Mult(node.lhs, node.rhs)
2733
node.replace_by(mult)

src/bloqade/squin/op/types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import overload
2+
13
from kirin import types
24

35

@@ -6,7 +8,16 @@ class Op:
68
def __matmul__(self, other: "Op") -> "Op":
79
raise NotImplementedError("@ can only be used within a squin kernel program")
810

9-
def __mul__(self, other: "Op") -> "Op":
11+
@overload
12+
def __mul__(self, other: "Op") -> "Op": ...
13+
14+
@overload
15+
def __mul__(self, other: int | float | complex) -> "Op": ...
16+
17+
def __mul__(self, other) -> "Op":
18+
raise NotImplementedError("@ can only be used within a squin kernel program")
19+
20+
def __rmul__(self, other: int | float | complex) -> "Op":
1021
raise NotImplementedError("@ can only be used within a squin kernel program")
1122

1223

test/squin/test_mult_rewrite.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def main():
2121
return q
2222

2323
helper.print()
24+
main.print()
2425

2526
assert isinstance(helper.code, func.Function)
2627

@@ -37,3 +38,82 @@ def main():
3738
count_mults_in_main += isinstance(stmt, squin.op.stmts.Mult)
3839

3940
assert count_mults_in_main == 1
41+
42+
43+
def test_scale_rewrite():
44+
45+
@squin.kernel
46+
def simple_rmul():
47+
x = squin.op.x()
48+
y = 2 * x
49+
return y
50+
51+
simple_rmul.print()
52+
53+
assert isinstance(simple_rmul.code, func.Function)
54+
55+
simple_rmul_stmts = list(simple_rmul.code.body.stmts())
56+
assert any(
57+
map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_rmul_stmts)
58+
)
59+
assert not any(
60+
map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_rmul_stmts)
61+
)
62+
assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_rmul_stmts))
63+
64+
@squin.kernel
65+
def simple_lmul():
66+
x = squin.op.x()
67+
y = x * 2
68+
return y
69+
70+
simple_lmul.print()
71+
72+
assert isinstance(simple_lmul.code, func.Function)
73+
74+
simple_lmul_stmts = list(simple_lmul.code.body.stmts())
75+
assert any(
76+
map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_lmul_stmts)
77+
)
78+
assert not any(
79+
map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_lmul_stmts)
80+
)
81+
assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_lmul_stmts))
82+
83+
@squin.kernel
84+
def scale_mult():
85+
x = squin.op.x()
86+
y = squin.op.y()
87+
return 2 * (x * y)
88+
89+
assert isinstance(scale_mult.code, func.Function)
90+
91+
scale_mult_stmts = list(scale_mult.code.body.stmts())
92+
assert (
93+
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult_stmts))
94+
== 1
95+
)
96+
assert (
97+
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult_stmts))
98+
== 1
99+
)
100+
101+
@squin.kernel
102+
def scale_mult2():
103+
x = squin.op.x()
104+
y = squin.op.y()
105+
return 2 * x * y
106+
107+
scale_mult2.print()
108+
109+
assert isinstance(scale_mult2.code, func.Function)
110+
111+
scale_mult2_stmts = list(scale_mult2.code.body.stmts())
112+
assert (
113+
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult2_stmts))
114+
== 1
115+
)
116+
assert (
117+
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult2_stmts))
118+
== 1
119+
)

0 commit comments

Comments
 (0)