Skip to content

Commit 09f9214

Browse files
david-plweinbe58
andauthored
Rewrite py.Mult to squin.Mult in squin kernel (#212)
Seemed sufficiently different from #207, so I created a new branch for it. @Roger-luo two things: * I'm not sure whether this should be a `Fixpoint`. * Should this also be part of the `wire` kernel? --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent cdf3280 commit 09f9214

File tree

8 files changed

+243
-10
lines changed

8 files changed

+243
-10
lines changed

src/bloqade/squin/groups.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from kirin.rewrite.walk import Walk
55

66
from . import op, wire, qubit
7+
from .op.rewrite import PyMultToSquinMult
78
from .rewrite.measure_desugar import MeasureDesugarRule
89

910

@@ -13,16 +14,21 @@ def kernel(self):
1314
typeinfer_pass = passes.TypeInfer(self)
1415
ilist_desugar_pass = ilist.IListDesugar(self)
1516
measure_desugar_pass = Walk(MeasureDesugarRule())
17+
py_mult_to_mult_pass = PyMultToSquinMult(self)
1618

1719
def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
1820
method.verify()
1921
if fold:
2022
fold_pass.fixpoint(method)
2123

24+
py_mult_to_mult_pass(method)
25+
2226
if typeinfer:
2327
typeinfer_pass(method)
2428
measure_desugar_pass.rewrite(method.code)
29+
2530
ilist_desugar_pass(method)
31+
2632
if typeinfer:
2733
typeinfer_pass(method) # fix types after desugaring
2834
method.verify_type()
@@ -32,7 +38,9 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
3238

3339
@ir.dialect_group(structural_no_opt.union([op, wire]))
3440
def wired(self):
41+
py_mult_to_mult_pass = PyMultToSquinMult(self)
42+
3543
def run_pass(method):
36-
pass
44+
py_mult_to_mult_pass(method)
3745

3846
return run_pass

src/bloqade/squin/op/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.prelude import structural_no_opt as _structural_no_opt
33
from kirin.lowering import wraps as _wraps
44

5-
from . import stmts as stmts, types as types
5+
from . import stmts as stmts, types as types, rewrite as rewrite
66
from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
77
from ._dialect import dialect as dialect
88

src/bloqade/squin/op/complex.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/bloqade/squin/op/number.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import numbers
2+
3+
from kirin.ir.attrs.types import PyClass
4+
5+
NumberType = PyClass(numbers.Number)

src/bloqade/squin/op/rewrite.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Rewrite py.binop.mult to Mult stmt"""
2+
3+
from kirin import ir
4+
from kirin.passes import Pass
5+
from kirin.rewrite import Walk
6+
from kirin.dialects import py
7+
from kirin.rewrite.abc import RewriteRule, RewriteResult
8+
9+
from .stmts import Mult, Scale
10+
from .types import OpType
11+
12+
13+
class _PyMultToSquinMult(RewriteRule):
14+
15+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
16+
if not isinstance(node, py.Mult):
17+
return RewriteResult()
18+
19+
lhs_is_op = node.lhs.type.is_subseteq(OpType)
20+
rhs_is_op = node.rhs.type.is_subseteq(OpType)
21+
22+
if not lhs_is_op and not rhs_is_op:
23+
return RewriteResult()
24+
25+
if lhs_is_op and rhs_is_op:
26+
mult = Mult(node.lhs, node.rhs)
27+
node.replace_by(mult)
28+
return RewriteResult(has_done_something=True)
29+
30+
if lhs_is_op:
31+
scale = Scale(node.lhs, node.rhs)
32+
node.replace_by(scale)
33+
return RewriteResult(has_done_something=True)
34+
35+
if rhs_is_op:
36+
scale = Scale(node.rhs, node.lhs)
37+
node.replace_by(scale)
38+
return RewriteResult(has_done_something=True)
39+
40+
return RewriteResult()
41+
42+
43+
class PyMultToSquinMult(Pass):
44+
45+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
46+
return Walk(_PyMultToSquinMult()).rewrite(mt.code)

src/bloqade/squin/op/stmts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from kirin.decl import info, statement
33

44
from .types import OpType
5+
from .number import NumberType
56
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
6-
from .complex import Complex
77
from ._dialect import dialect
88

99

@@ -54,7 +54,7 @@ class Scale(CompositeOp):
5454
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
5555
is_unitary: bool = info.attribute(default=False)
5656
op: ir.SSAValue = info.argument(OpType)
57-
factor: ir.SSAValue = info.argument(Complex)
57+
factor: ir.SSAValue = info.argument(NumberType)
5858
result: ir.ResultValue = info.result(OpType)
5959

6060

src/bloqade/squin/op/types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import overload
2+
13
from kirin import types
24

35

@@ -6,5 +8,17 @@ class Op:
68
def __matmul__(self, other: "Op") -> "Op":
79
raise NotImplementedError("@ can only be used within a squin kernel program")
810

11+
@overload
12+
def __mul__(self, other: "Op") -> "Op": ...
13+
14+
@overload
15+
def __mul__(self, other: complex) -> "Op": ...
16+
17+
def __mul__(self, other) -> "Op":
18+
raise NotImplementedError("@ can only be used within a squin kernel program")
19+
20+
def __rmul__(self, other: complex) -> "Op":
21+
raise NotImplementedError("@ can only be used within a squin kernel program")
22+
923

1024
OpType = types.PyClass(Op)

test/squin/test_mult_rewrite.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from kirin.types import PyClass
2+
from kirin.dialects import py, func
3+
4+
from bloqade import squin
5+
6+
7+
def test_mult_rewrite():
8+
9+
@squin.kernel
10+
def helper(x: squin.op.types.Op, y: squin.op.types.Op):
11+
return x * y
12+
13+
@squin.kernel
14+
def main():
15+
q = squin.qubit.new(1)
16+
x = squin.op.x()
17+
y = squin.op.y()
18+
z = x * y
19+
t = helper(x, z)
20+
21+
squin.qubit.apply(t, q)
22+
return q
23+
24+
helper.print()
25+
main.print()
26+
27+
assert isinstance(helper.code, func.Function)
28+
29+
helper_stmts = list(helper.code.body.stmts())
30+
assert len(helper_stmts) == 2 # [Mult(), Return()]
31+
assert isinstance(helper_stmts[0], squin.op.stmts.Mult)
32+
33+
assert isinstance(main.code, func.Function)
34+
35+
count_mults_in_main = 0
36+
for stmt in main.code.body.stmts():
37+
assert not isinstance(stmt, py.Mult)
38+
39+
count_mults_in_main += isinstance(stmt, squin.op.stmts.Mult)
40+
41+
assert count_mults_in_main == 1
42+
43+
44+
def test_scale_rewrite():
45+
46+
@squin.kernel
47+
def simple_rmul():
48+
x = squin.op.x()
49+
y = 2 * x
50+
return y
51+
52+
simple_rmul.print()
53+
54+
assert isinstance(simple_rmul.code, func.Function)
55+
56+
simple_rmul_stmts = list(simple_rmul.code.body.stmts())
57+
assert any(
58+
map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_rmul_stmts)
59+
)
60+
assert not any(
61+
map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_rmul_stmts)
62+
)
63+
assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_rmul_stmts))
64+
65+
@squin.kernel
66+
def simple_lmul():
67+
x = squin.op.x()
68+
y = x * 2
69+
return y
70+
71+
simple_lmul.print()
72+
73+
assert isinstance(simple_lmul.code, func.Function)
74+
75+
simple_lmul_stmts = list(simple_lmul.code.body.stmts())
76+
assert any(
77+
map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_lmul_stmts)
78+
)
79+
assert not any(
80+
map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_lmul_stmts)
81+
)
82+
assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_lmul_stmts))
83+
84+
@squin.kernel
85+
def scale_mult():
86+
x = squin.op.x()
87+
y = squin.op.y()
88+
return 2 * (x * y)
89+
90+
assert isinstance(scale_mult.code, func.Function)
91+
92+
scale_mult_stmts = list(scale_mult.code.body.stmts())
93+
assert (
94+
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult_stmts))
95+
== 1
96+
)
97+
assert (
98+
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult_stmts))
99+
== 1
100+
)
101+
102+
@squin.kernel
103+
def scale_mult2():
104+
x = squin.op.x()
105+
y = squin.op.y()
106+
return 2 * x * y
107+
108+
scale_mult2.print()
109+
110+
assert isinstance(scale_mult2.code, func.Function)
111+
112+
scale_mult2_stmts = list(scale_mult2.code.body.stmts())
113+
assert (
114+
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult2_stmts))
115+
== 1
116+
)
117+
assert (
118+
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult2_stmts))
119+
== 1
120+
)
121+
122+
123+
def test_scale_types():
124+
@squin.kernel
125+
def simple_lmul():
126+
x = squin.op.x()
127+
y = x * (2 + 0j)
128+
return y
129+
130+
@squin.kernel
131+
def simple_rmul():
132+
x = squin.op.x()
133+
y = 2.1 * x
134+
return y
135+
136+
@squin.kernel
137+
def nested_rmul():
138+
x = squin.op.x()
139+
y = squin.op.y()
140+
return 2 * x * y
141+
142+
@squin.kernel
143+
def nested_rmul2():
144+
x = squin.op.x()
145+
y = squin.op.y()
146+
return 2 * (x * y)
147+
148+
@squin.kernel
149+
def nested_lmul():
150+
x = squin.op.x()
151+
y = squin.op.y()
152+
return x * y * 2.0j
153+
154+
def check_stmt_type(code, typ):
155+
for stmt in code.body.stmts():
156+
if isinstance(stmt, func.Return):
157+
continue
158+
is_op = stmt.result.type.is_subseteq(squin.op.types.OpType)
159+
is_num = stmt.result.type.is_equal(PyClass(typ))
160+
assert is_op or is_num
161+
162+
check_stmt_type(simple_lmul.code, complex)
163+
check_stmt_type(simple_rmul.code, float)
164+
check_stmt_type(nested_rmul.code, int)
165+
check_stmt_type(nested_rmul2.code, int)
166+
check_stmt_type(nested_lmul.code, complex)

0 commit comments

Comments
 (0)