diff --git a/src/kirin/dialects/ilist/rewrite/const.py b/src/kirin/dialects/ilist/rewrite/const.py index 12b95614e..dedf58181 100644 --- a/src/kirin/dialects/ilist/rewrite/const.py +++ b/src/kirin/dialects/ilist/rewrite/const.py @@ -27,11 +27,15 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: typ = result.type data = hint.data if isinstance(typ, types.PyClass) and typ.is_subseteq(types.PyClass(IList)): - has_done_something = self._rewrite_IList_type(result, data) + has_done_something = has_done_something or self._rewrite_IList_type( + result, data + ) elif isinstance(typ, types.Generic) and typ.body.is_subseteq( types.PyClass(IList) ): - has_done_something = self._rewrite_IList_type(result, data) + has_done_something = has_done_something or self._rewrite_IList_type( + result, data + ) return RewriteResult(has_done_something=has_done_something) def rewrite_Constant(self, node: Constant) -> RewriteResult: @@ -53,6 +57,13 @@ def _rewrite_IList_type(self, result: ir.SSAValue, data): for elem in data.data[1:]: elem_type = elem_type.join(types.PyClass(type(elem))) - result.type = IListType[elem_type, types.Literal(len(data.data))] - result.hints["const"] = const.Value(data) + new_type = IListType[elem_type, types.Literal(len(data.data))] + new_hint = const.Value(data) + + # Check if type and hint are already correct + if result.type == new_type and result.hints.get("const") == new_hint: + return False + + result.type = new_type + result.hints["const"] = new_hint return True diff --git a/src/kirin/dialects/ilist/rewrite/hint_len.py b/src/kirin/dialects/ilist/rewrite/hint_len.py index d0d0d5d26..b20229ec7 100644 --- a/src/kirin/dialects/ilist/rewrite/hint_len.py +++ b/src/kirin/dialects/ilist/rewrite/hint_len.py @@ -32,6 +32,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if (coll_len := self._get_collection_len(node.value)) is None: return RewriteResult() - node.result.hints["const"] = const.Value(coll_len) + existing_hint = node.result.hints.get("const") + new_hint = const.Value(coll_len) + if existing_hint is not None and new_hint.is_structurally_equal(existing_hint): + return RewriteResult() + + node.result.hints["const"] = new_hint return RewriteResult(has_done_something=True) diff --git a/src/kirin/dialects/ilist/rewrite/list.py b/src/kirin/dialects/ilist/rewrite/list.py index 1b9963c98..4de60beef 100644 --- a/src/kirin/dialects/ilist/rewrite/list.py +++ b/src/kirin/dialects/ilist/rewrite/list.py @@ -13,7 +13,7 @@ class List2IList(RewriteRule): def rewrite_Block(self, node: ir.Block) -> RewriteResult: has_done_something = False for arg in node.args: - has_done_something = self._rewrite_SSAValue_type(arg) + has_done_something = has_done_something or self._rewrite_SSAValue_type(arg) return RewriteResult(has_done_something=has_done_something) def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: @@ -25,7 +25,9 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: ) for result in node.results: - has_done_something = self._rewrite_SSAValue_type(result) + has_done_something = has_done_something or self._rewrite_SSAValue_type( + result + ) return RewriteResult(has_done_something=has_done_something) diff --git a/src/kirin/passes/abc.py b/src/kirin/passes/abc.py index a43a28787..40309de9e 100644 --- a/src/kirin/passes/abc.py +++ b/src/kirin/passes/abc.py @@ -36,7 +36,7 @@ def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult: for _ in range(max_iter): result_ = self.unsafe_run(mt) result = result_.join(result) - if not result.has_done_something: + if not result_.has_done_something: break mt.verify() return result diff --git a/src/kirin/rewrite/apply_type.py b/src/kirin/rewrite/apply_type.py index 1be7c3ada..9004da021 100644 --- a/src/kirin/rewrite/apply_type.py +++ b/src/kirin/rewrite/apply_type.py @@ -13,8 +13,10 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult: has_done_something = False for arg in node.args: if arg in self.results: - arg.type = self.results[arg] - has_done_something = True + arg_type = self.results[arg] + if arg.type != arg_type: + arg.type = arg_type + has_done_something = True return RewriteResult(has_done_something=has_done_something) @@ -22,8 +24,10 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: has_done_something = False for result in node._results: if result in self.results: - result.type = self.results[result] - has_done_something = True + arg_type = self.results[result] + if result.type != arg_type: + result.type = arg_type + has_done_something = True if (trait := node.get_trait(ir.HasSignature)) is not None and ( callable_trait := node.get_trait(ir.CallableStmtInterface) diff --git a/src/kirin/rewrite/cse.py b/src/kirin/rewrite/cse.py index 1740bcf54..d154eaf44 100644 --- a/src/kirin/rewrite/cse.py +++ b/src/kirin/rewrite/cse.py @@ -61,6 +61,7 @@ class CommonSubexpressionElimination(RewriteRule): def rewrite_Block(self, node: ir.Block) -> RewriteResult: seen: dict[Info, ir.Statement] = {} + has_done_something = False for stmt in node.stmts: if not stmt.has_trait(ir.Pure): @@ -81,10 +82,10 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult: for result, old_result in zip(stmt._results, old_stmt.results): result.replace_by(old_result) stmt.delete() - return RewriteResult(has_done_something=True) + has_done_something = True else: seen[info] = stmt - return RewriteResult() + return RewriteResult(has_done_something=has_done_something) def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if not node.regions: diff --git a/src/kirin/rewrite/fold.py b/src/kirin/rewrite/fold.py index e2d9d65d8..b7a74eb7f 100644 --- a/src/kirin/rewrite/fold.py +++ b/src/kirin/rewrite/fold.py @@ -29,6 +29,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: has_done_something = False for old_result in node.results: if (value := self.get_const(old_result)) is not None: + if not old_result.uses: + continue stmt = Constant(value.data) stmt.insert_before(node) old_result.replace_by(stmt.result) diff --git a/src/kirin/rewrite/walk.py b/src/kirin/rewrite/walk.py index 4d9a23638..9b8b7b51c 100644 --- a/src/kirin/rewrite/walk.py +++ b/src/kirin/rewrite/walk.py @@ -31,6 +31,8 @@ def rewrite(self, node: IRNode) -> RewriteResult: # NOTE: because the rewrite pass may mutate the node # thus we need to save the list of nodes to be processed # first before we start processing them + assert self.worklist.is_empty() + self.populate_worklist(node) has_done_something = False subnode = self.worklist.pop() diff --git a/src/kirin/rewrite/wrap_const.py b/src/kirin/rewrite/wrap_const.py index 06a40a335..93d5a189d 100644 --- a/src/kirin/rewrite/wrap_const.py +++ b/src/kirin/rewrite/wrap_const.py @@ -48,8 +48,10 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: has_done_something = True if ( - trait := node.get_trait(ir.MaybePure) - ) and node in self.frame.should_be_pure: + (trait := node.get_trait(ir.MaybePure)) + and node in self.frame.should_be_pure + and not trait.is_pure(node) + ): trait.set_pure(node) has_done_something = True return RewriteResult(has_done_something=has_done_something) diff --git a/test/analysis/dataflow/typeinfer/test_always_rewrite.py b/test/analysis/dataflow/typeinfer/test_always_rewrite.py deleted file mode 100644 index 72c0822e1..000000000 --- a/test/analysis/dataflow/typeinfer/test_always_rewrite.py +++ /dev/null @@ -1,18 +0,0 @@ -from kirin.passes import TypeInfer -from kirin.prelude import basic_no_opt - - -def test_always_rewrites(): - @basic_no_opt - def unstable(x: int): # type: ignore - y = x + 1 - if y > 10: - z = y - else: - z = y + 1.2 - return z - - result = TypeInfer(dialects=unstable.dialects, no_raise=False).fixpoint(unstable) - assert ( - result.has_done_something - ) # this will always be true because TypeInfer always rewrites type diff --git a/test/dialects/test_infer_len.py b/test/dialects/test_infer_len.py index 0576a1ac0..f76553657 100644 --- a/test/dialects/test_infer_len.py +++ b/test/dialects/test_infer_len.py @@ -7,7 +7,13 @@ def test(): rule = rewrite.Fixpoint( - rewrite.Walk(rewrite.Chain(ilist.rewrite.HintLen(), rewrite.ConstantFold())) + rewrite.Walk( + rewrite.Chain( + ilist.rewrite.HintLen(), + rewrite.ConstantFold(), + rewrite.DeadCodeElimination(), + ) + ) ) @basic @@ -24,6 +30,8 @@ def len_func3(xs: ilist.IList[int, Any]): stmt = len_func.callable_region.blocks[0].stmts.at(0) assert isinstance(stmt, py.Constant) assert stmt.value.unwrap() == 3 + assert len(len_func.callable_region.blocks[0].stmts) == 2 stmt = len_func3.callable_region.blocks[0].stmts.at(0) assert isinstance(stmt, py.Len) + assert len(len_func3.callable_region.blocks[0].stmts) == 2 diff --git a/test/testing/test_assert_structurally_same.py b/test/testing/test_assert_structurally_same.py index 1de3380e2..186f109f6 100644 --- a/test/testing/test_assert_structurally_same.py +++ b/test/testing/test_assert_structurally_same.py @@ -1,7 +1,7 @@ import pytest -from kirin.testing import assert_structurally_same from kirin.prelude import structural_no_opt +from kirin.testing import assert_structurally_same from kirin.dialects import py, func