Skip to content

Commit 298d9e1

Browse files
authored
fix more bugs for bloqade qft (#245)
1 parent 480ffc5 commit 298d9e1

File tree

6 files changed

+18
-11
lines changed

6 files changed

+18
-11
lines changed

src/kirin/dialects/py/binop/typeinfer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def subf(self, *_):
2727
def subi(self, *_):
2828
return (types.Int,)
2929

30-
@interp.impl(stmts.Sub, types.Float, types.Float)
31-
@interp.impl(stmts.Sub, types.Float, types.Int)
32-
@interp.impl(stmts.Sub, types.Int, types.Float)
30+
@interp.impl(stmts.Mult, types.Float, types.Float)
31+
@interp.impl(stmts.Mult, types.Float, types.Int)
32+
@interp.impl(stmts.Mult, types.Int, types.Float)
3333
def multf(self, *_):
3434
return (types.Float,)
3535

src/kirin/dialects/scf/absint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from kirin import ir, types, interp
1+
from kirin import ir, interp
22
from kirin.analysis import const
33
from kirin.dialects import func
44

@@ -53,10 +53,10 @@ def _infer_if_else_cond(
5353
body_block = body.blocks[0]
5454
body_term = body_block.last_stmt
5555
if isinstance(body_term, func.Return):
56-
frame.worklist.append(interp.Successor(body_block, types.Bool))
56+
frame.worklist.append(interp.Successor(body_block, frame.get(stmt.cond)))
5757
return
5858

5959
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
6060
body_frame.entries.update(frame.entries)
61-
body_frame.set(body_block.args[0], types.Bool)
62-
return interp_.run_ssacfg_region(body_frame, stmt.then_body)
61+
body_frame.set(body_block.args[0], frame.get(stmt.cond))
62+
return interp_.run_ssacfg_region(body_frame, body)

src/kirin/dialects/scf/typeinfer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,19 @@ def for_loop(
3737

3838
eltype = interp_.run_stmt(ElType(ir.TestValue()), (iterable,))
3939
if not isinstance(eltype, tuple): # error
40-
return (interp_.lattice.bottom(),)
40+
return
4141
item = eltype[0]
4242
frame.set_values(block_args, (item,) + loop_vars)
4343

4444
if isinstance(body_block.last_stmt, func.Return):
4545
frame.worklist.append(interp.Successor(body_block, item, *loop_vars))
4646
return # if terminate is Return, there is no result
4747

48-
loop_vars_ = interp_.run_ssacfg_region(frame, stmt.body)
48+
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
49+
body_frame.entries.update(frame.entries)
50+
loop_vars_ = interp_.run_ssacfg_region(body_frame, stmt.body)
51+
52+
frame.entries.update(body_frame.entries)
4953
if isinstance(loop_vars_, interp.ReturnValue):
5054
return loop_vars_
5155
elif isinstance(loop_vars_, tuple):

src/kirin/dialects/scf/unroll.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
6262
# TODO: support for PartialTuple and IList with known length
6363
if not isinstance(hint := node.iterable.hints.get("const"), const.Value):
6464
return RewriteResult()
65-
65+
print(hint)
6666
loop_vars = node.initializers
6767
for item in hint.data:
6868
body = node.body.clone()

src/kirin/passes/abc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def __call__(self, mt: Method) -> RewriteResult:
3333
def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult:
3434
result = RewriteResult()
3535
for _ in range(max_iter):
36-
result = self.unsafe_run(mt).join(result)
36+
result_ = self.unsafe_run(mt)
37+
result = result_.join(result)
3738
if not result.has_done_something:
3839
break
3940
mt.code.verify()

src/kirin/rewrite/wrap_const.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def wrap(self, value: ir.SSAValue) -> bool:
2727
const_hint = value.hints.get("const")
2828
if const_hint and isinstance(const_hint, const.Result):
2929
const_result = result.join(const_hint)
30+
if const_result.is_equal(const_hint):
31+
return False
3032
else:
3133
const_result = result
3234
value.hints["const"] = const_result

0 commit comments

Comments
 (0)