Skip to content

Commit 2c78a87

Browse files
committed
Fix type inference
1 parent 020845a commit 2c78a87

File tree

7 files changed

+52
-18
lines changed

7 files changed

+52
-18
lines changed

src/bloqade/squin/groups.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
1818
if fold:
1919
fold_pass.fixpoint(method)
2020

21+
py_mult_to_mult_pass(method)
22+
2123
if typeinfer:
2224
typeinfer_pass(method)
2325

2426
ilist_desugar_pass(method)
25-
py_mult_to_mult_pass(method)
2627

2728
if typeinfer:
2829
typeinfer_pass(method) # fix types after desugaring

src/bloqade/squin/op/complex.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/bloqade/squin/op/number.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import numbers
2+
3+
from kirin.ir.attrs.types import PyClass
4+
5+
Number = PyClass(numbers.Number)

src/bloqade/squin/op/rewrite.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from kirin.dialects import py
77
from kirin.rewrite.abc import RewriteRule, RewriteResult
88

9-
from .stmts import Mult, Scale, Operator
9+
from .stmts import Mult, Scale
1010
from .types import OpType
1111

1212

@@ -22,12 +22,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2222
if not lhs_is_op and not rhs_is_op:
2323
return RewriteResult()
2424

25-
if isinstance(node.lhs, ir.ResultValue):
26-
lhs_is_op = isinstance(node.lhs.stmt, Operator)
27-
28-
if isinstance(node.rhs, ir.ResultValue):
29-
rhs_is_op = isinstance(node.rhs.stmt, Operator)
30-
3125
if lhs_is_op and rhs_is_op:
3226
mult = Mult(node.lhs, node.rhs)
3327
node.replace_by(mult)

src/bloqade/squin/op/stmts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from kirin.decl import info, statement
33

44
from .types import OpType
5+
from .number import Number
56
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
6-
from .complex import Complex
77
from ._dialect import dialect
88

99

@@ -54,7 +54,7 @@ class Scale(CompositeOp):
5454
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
5555
is_unitary: bool = info.attribute(default=False)
5656
op: ir.SSAValue = info.argument(OpType)
57-
factor: ir.SSAValue = info.argument(Complex)
57+
factor: ir.SSAValue = info.argument(Number)
5858
result: ir.ResultValue = info.result(OpType)
5959

6060

src/bloqade/squin/op/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ def __matmul__(self, other: "Op") -> "Op":
1212
def __mul__(self, other: "Op") -> "Op": ...
1313

1414
@overload
15-
def __mul__(self, other: int | float | complex) -> "Op": ...
15+
def __mul__(self, other: complex) -> "Op": ...
1616

1717
def __mul__(self, other) -> "Op":
1818
raise NotImplementedError("@ can only be used within a squin kernel program")
1919

20-
def __rmul__(self, other: int | float | complex) -> "Op":
20+
def __rmul__(self, other: complex) -> "Op":
2121
raise NotImplementedError("@ can only be used within a squin kernel program")
2222

2323

test/squin/test_mult_rewrite.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from kirin.types import PyClass
12
from kirin.dialects import py, func
23

34
from bloqade import squin
@@ -117,3 +118,42 @@ def scale_mult2():
117118
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult2_stmts))
118119
== 1
119120
)
121+
122+
123+
def test_scale_types():
124+
@squin.kernel
125+
def simple_lmul():
126+
x = squin.op.x()
127+
y = x * (2 + 0j)
128+
return y
129+
130+
@squin.kernel
131+
def simple_rmul():
132+
x = squin.op.x()
133+
y = 2.1 * x
134+
return y
135+
136+
@squin.kernel
137+
def nested_rmul():
138+
x = squin.op.x()
139+
y = squin.op.y()
140+
return 2 * x * y
141+
142+
@squin.kernel
143+
def nested_lmul():
144+
x = squin.op.x()
145+
y = squin.op.y()
146+
return x * y * 2.0j
147+
148+
def check_stmt_type(code, typ):
149+
for stmt in code.body.stmts():
150+
if isinstance(stmt, func.Return):
151+
continue
152+
is_op = stmt.result.type.is_subseteq(squin.op.types.OpType)
153+
is_num = stmt.result.type.is_equal(PyClass(typ))
154+
assert is_op or is_num
155+
156+
check_stmt_type(simple_lmul.code, complex)
157+
check_stmt_type(simple_rmul.code, float)
158+
check_stmt_type(nested_rmul.code, int)
159+
check_stmt_type(nested_lmul.code, complex)

0 commit comments

Comments
 (0)