Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/kirin/dialects/ilist/rewrite/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
typ = result.type
data = hint.data
if isinstance(typ, types.PyClass) and typ.is_subseteq(types.PyClass(IList)):
has_done_something = self._rewrite_IList_type(result, data)
has_done_something = has_done_something or self._rewrite_IList_type(
result, data
)
elif isinstance(typ, types.Generic) and typ.body.is_subseteq(
types.PyClass(IList)
):
has_done_something = self._rewrite_IList_type(result, data)
has_done_something = has_done_something or self._rewrite_IList_type(
result, data
)
return RewriteResult(has_done_something=has_done_something)

def rewrite_Constant(self, node: Constant) -> RewriteResult:
Expand All @@ -53,6 +57,13 @@ def _rewrite_IList_type(self, result: ir.SSAValue, data):
for elem in data.data[1:]:
elem_type = elem_type.join(types.PyClass(type(elem)))

result.type = IListType[elem_type, types.Literal(len(data.data))]
result.hints["const"] = const.Value(data)
new_type = IListType[elem_type, types.Literal(len(data.data))]
new_hint = const.Value(data)

# Check if type and hint are already correct
if result.type == new_type and result.hints.get("const") == new_hint:
return False

result.type = new_type
result.hints["const"] = new_hint
return True
9 changes: 8 additions & 1 deletion src/kirin/dialects/ilist/rewrite/hint_len.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import cast

from kirin import ir, types
from kirin.analysis import const
from kirin.dialects import py
Expand Down Expand Up @@ -32,6 +34,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if (coll_len := self._get_collection_len(node.value)) is None:
return RewriteResult()

node.result.hints["const"] = const.Value(coll_len)
existing_hint = cast(const.Result | None, node.result.hints.get("const"))
new_hint = const.Value(coll_len)

if existing_hint is not None and new_hint.is_equal(existing_hint):
return RewriteResult()

node.result.hints["const"] = new_hint
return RewriteResult(has_done_something=True)
6 changes: 4 additions & 2 deletions src/kirin/dialects/ilist/rewrite/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class List2IList(RewriteRule):
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
has_done_something = False
for arg in node.args:
has_done_something = self._rewrite_SSAValue_type(arg)
has_done_something = has_done_something or self._rewrite_SSAValue_type(arg)
return RewriteResult(has_done_something=has_done_something)

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
Expand All @@ -25,7 +25,9 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
)

for result in node.results:
has_done_something = self._rewrite_SSAValue_type(result)
has_done_something = has_done_something or self._rewrite_SSAValue_type(
result
)

return RewriteResult(has_done_something=has_done_something)

Expand Down
2 changes: 1 addition & 1 deletion src/kirin/passes/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult:
for _ in range(max_iter):
result_ = self.unsafe_run(mt)
result = result_.join(result)
if not result.has_done_something:
if not result_.has_done_something:
break
mt.verify()
return result
Expand Down
12 changes: 8 additions & 4 deletions src/kirin/rewrite/apply_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult:
has_done_something = False
for arg in node.args:
if arg in self.results:
arg.type = self.results[arg]
has_done_something = True
arg_type = self.results[arg]
if arg.type != arg_type:
arg.type = arg_type
has_done_something = True

return RewriteResult(has_done_something=has_done_something)

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
has_done_something = False
for result in node._results:
if result in self.results:
result.type = self.results[result]
has_done_something = True
arg_type = self.results[result]
if result.type != arg_type:
result.type = arg_type
has_done_something = True

if (trait := node.get_trait(ir.HasSignature)) is not None and (
callable_trait := node.get_trait(ir.CallableStmtInterface)
Expand Down
5 changes: 3 additions & 2 deletions src/kirin/rewrite/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class CommonSubexpressionElimination(RewriteRule):

def rewrite_Block(self, node: ir.Block) -> RewriteResult:
seen: dict[Info, ir.Statement] = {}
has_done_something = False

for stmt in node.stmts:
if not stmt.has_trait(ir.Pure):
Expand All @@ -81,10 +82,10 @@ def rewrite_Block(self, node: ir.Block) -> RewriteResult:
for result, old_result in zip(stmt._results, old_stmt.results):
result.replace_by(old_result)
stmt.delete()
return RewriteResult(has_done_something=True)
has_done_something = True
else:
seen[info] = stmt
return RewriteResult()
return RewriteResult(has_done_something=has_done_something)

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not node.regions:
Expand Down
2 changes: 2 additions & 0 deletions src/kirin/rewrite/fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
has_done_something = False
for old_result in node.results:
if (value := self.get_const(old_result)) is not None:
if not old_result.uses:
continue
stmt = Constant(value.data)
stmt.insert_before(node)
old_result.replace_by(stmt.result)
Expand Down
2 changes: 2 additions & 0 deletions src/kirin/rewrite/walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def rewrite(self, node: IRNode) -> RewriteResult:
# NOTE: because the rewrite pass may mutate the node
# thus we need to save the list of nodes to be processed
# first before we start processing them
assert self.worklist.is_empty()

self.populate_worklist(node)
has_done_something = False
subnode = self.worklist.pop()
Expand Down
6 changes: 4 additions & 2 deletions src/kirin/rewrite/wrap_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
has_done_something = True

if (
trait := node.get_trait(ir.MaybePure)
) and node in self.frame.should_be_pure:
(trait := node.get_trait(ir.MaybePure))
and node in self.frame.should_be_pure
and not trait.is_pure(node)
):
trait.set_pure(node)
has_done_something = True
return RewriteResult(has_done_something=has_done_something)
18 changes: 0 additions & 18 deletions test/analysis/dataflow/typeinfer/test_always_rewrite.py

This file was deleted.

10 changes: 9 additions & 1 deletion test/dialects/test_infer_len.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

def test():
rule = rewrite.Fixpoint(
rewrite.Walk(rewrite.Chain(ilist.rewrite.HintLen(), rewrite.ConstantFold()))
rewrite.Walk(
rewrite.Chain(
ilist.rewrite.HintLen(),
rewrite.ConstantFold(),
rewrite.DeadCodeElimination(),
)
)
)

@basic
Expand All @@ -24,6 +30,8 @@ def len_func3(xs: ilist.IList[int, Any]):
stmt = len_func.callable_region.blocks[0].stmts.at(0)
assert isinstance(stmt, py.Constant)
assert stmt.value.unwrap() == 3
assert len(len_func.callable_region.blocks[0].stmts) == 2

stmt = len_func3.callable_region.blocks[0].stmts.at(0)
assert isinstance(stmt, py.Len)
assert len(len_func3.callable_region.blocks[0].stmts) == 2