Skip to content

Commit 10b4f47

Browse files
authored
Rafaelha/fix has_done_something and fixpoint logic (#545)
This PR fixes several issues related to fixpoint rewrites: 1. A subtle bug in `Pass.fixpoint` lead to rewrites being applied for `max_iter` iterations, even if they had already converged. 2. Several compiler passes were incorrectly returning `has_done_something=True`, even if they had not done anything, leading to maxed out fixpoint loops. 3. The CSE rewrite would prematurely exit, leading to more fixpoint iterations. I checked that all tests in the kirin test suite lead to fixpoint rewrites that converge before reaching `max_iter`. It's still possible that there are more rewrite passes that need to be fixed. We should consider raising an error or at least a warning if a fixpoint pass does not converge, as this will immediately point us to any such issues in the future. **The default `max_iter` is set to 32. With this fix, many passes now iterate only once or twice instead of 32 times. So in many cases there is a 32x speedup. Some passes also contain a fixpoint within a fixpoint. They can now become 1000x faster.** I noticed these issues when running the following script ```python @qasm2.extended # O(n) scaling def main_fun(): qreg = qasm2.qreg(1) for _ in range(n): qasm2.h(qreg[0]) QASM2().emit(main_fun) # ~O(n^1.5) ``` This script runs really slow for large `n`. But even more important, recording the duration of compilation and QASM2 emission, I found non-linear scaling ~ O(n^1.5). Any compiler pass must be linear in the program size (or at worst n log n). <img width="331" height="250" alt="image" src="https://github.com/user-attachments/assets/c914db82-e2c6-4821-8a0e-b00ab19b432c" /> With the fixes in this PR, the execution time of this script is significantly faster and the computational complexity is back to the expected O(n): <img width="331" height="245" alt="image" src="https://github.com/user-attachments/assets/35480c95-29b5-4117-b796-c4cc891d1cb0" /> These issues also affect the runtime of @cduck's flair-to-squin pipeline - but I think there are additional slowdowns that I'm still looking into.
1 parent 537525f commit 10b4f47

File tree

12 files changed

+55
-36
lines changed

12 files changed

+55
-36
lines changed

src/kirin/dialects/ilist/rewrite/const.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,15 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2727
typ = result.type
2828
data = hint.data
2929
if isinstance(typ, types.PyClass) and typ.is_subseteq(types.PyClass(IList)):
30-
has_done_something = self._rewrite_IList_type(result, data)
30+
has_done_something = has_done_something or self._rewrite_IList_type(
31+
result, data
32+
)
3133
elif isinstance(typ, types.Generic) and typ.body.is_subseteq(
3234
types.PyClass(IList)
3335
):
34-
has_done_something = self._rewrite_IList_type(result, data)
36+
has_done_something = has_done_something or self._rewrite_IList_type(
37+
result, data
38+
)
3539
return RewriteResult(has_done_something=has_done_something)
3640

3741
def rewrite_Constant(self, node: Constant) -> RewriteResult:
@@ -53,6 +57,13 @@ def _rewrite_IList_type(self, result: ir.SSAValue, data):
5357
for elem in data.data[1:]:
5458
elem_type = elem_type.join(types.PyClass(type(elem)))
5559

56-
result.type = IListType[elem_type, types.Literal(len(data.data))]
57-
result.hints["const"] = const.Value(data)
60+
new_type = IListType[elem_type, types.Literal(len(data.data))]
61+
new_hint = const.Value(data)
62+
63+
# Check if type and hint are already correct
64+
if result.type == new_type and result.hints.get("const") == new_hint:
65+
return False
66+
67+
result.type = new_type
68+
result.hints["const"] = new_hint
5869
return True

src/kirin/dialects/ilist/rewrite/hint_len.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
3232
if (coll_len := self._get_collection_len(node.value)) is None:
3333
return RewriteResult()
3434

35-
node.result.hints["const"] = const.Value(coll_len)
35+
existing_hint = node.result.hints.get("const")
36+
new_hint = const.Value(coll_len)
3637

38+
if existing_hint is not None and new_hint.is_structurally_equal(existing_hint):
39+
return RewriteResult()
40+
41+
node.result.hints["const"] = new_hint
3742
return RewriteResult(has_done_something=True)

src/kirin/dialects/ilist/rewrite/list.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class List2IList(RewriteRule):
1313
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
1414
has_done_something = False
1515
for arg in node.args:
16-
has_done_something = self._rewrite_SSAValue_type(arg)
16+
has_done_something = has_done_something or self._rewrite_SSAValue_type(arg)
1717
return RewriteResult(has_done_something=has_done_something)
1818

1919
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
@@ -25,7 +25,9 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2525
)
2626

2727
for result in node.results:
28-
has_done_something = self._rewrite_SSAValue_type(result)
28+
has_done_something = has_done_something or self._rewrite_SSAValue_type(
29+
result
30+
)
2931

3032
return RewriteResult(has_done_something=has_done_something)
3133

src/kirin/passes/abc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult:
3636
for _ in range(max_iter):
3737
result_ = self.unsafe_run(mt)
3838
result = result_.join(result)
39-
if not result.has_done_something:
39+
if not result_.has_done_something:
4040
break
4141
mt.verify()
4242
return result

src/kirin/rewrite/apply_type.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,21 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult:
1313
has_done_something = False
1414
for arg in node.args:
1515
if arg in self.results:
16-
arg.type = self.results[arg]
17-
has_done_something = True
16+
arg_type = self.results[arg]
17+
if arg.type != arg_type:
18+
arg.type = arg_type
19+
has_done_something = True
1820

1921
return RewriteResult(has_done_something=has_done_something)
2022

2123
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2224
has_done_something = False
2325
for result in node._results:
2426
if result in self.results:
25-
result.type = self.results[result]
26-
has_done_something = True
27+
arg_type = self.results[result]
28+
if result.type != arg_type:
29+
result.type = arg_type
30+
has_done_something = True
2731

2832
if (trait := node.get_trait(ir.HasSignature)) is not None and (
2933
callable_trait := node.get_trait(ir.CallableStmtInterface)

src/kirin/rewrite/cse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class CommonSubexpressionElimination(RewriteRule):
6161

6262
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
6363
seen: dict[Info, ir.Statement] = {}
64+
has_done_something = False
6465

6566
for stmt in node.stmts:
6667
if not stmt.has_trait(ir.Pure):
@@ -81,10 +82,10 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult:
8182
for result, old_result in zip(stmt._results, old_stmt.results):
8283
result.replace_by(old_result)
8384
stmt.delete()
84-
return RewriteResult(has_done_something=True)
85+
has_done_something = True
8586
else:
8687
seen[info] = stmt
87-
return RewriteResult()
88+
return RewriteResult(has_done_something=has_done_something)
8889

8990
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
9091
if not node.regions:

src/kirin/rewrite/fold.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2929
has_done_something = False
3030
for old_result in node.results:
3131
if (value := self.get_const(old_result)) is not None:
32+
if not old_result.uses:
33+
continue
3234
stmt = Constant(value.data)
3335
stmt.insert_before(node)
3436
old_result.replace_by(stmt.result)

src/kirin/rewrite/walk.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def rewrite(self, node: IRNode) -> RewriteResult:
3131
# NOTE: because the rewrite pass may mutate the node
3232
# thus we need to save the list of nodes to be processed
3333
# first before we start processing them
34+
assert self.worklist.is_empty()
35+
3436
self.populate_worklist(node)
3537
has_done_something = False
3638
subnode = self.worklist.pop()

src/kirin/rewrite/wrap_const.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
4848
has_done_something = True
4949

5050
if (
51-
trait := node.get_trait(ir.MaybePure)
52-
) and node in self.frame.should_be_pure:
51+
(trait := node.get_trait(ir.MaybePure))
52+
and node in self.frame.should_be_pure
53+
and not trait.is_pure(node)
54+
):
5355
trait.set_pure(node)
5456
has_done_something = True
5557
return RewriteResult(has_done_something=has_done_something)

test/analysis/dataflow/typeinfer/test_always_rewrite.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)