Skip to content

Commit 139977b

Browse files
authored
fix aggressive rules (#123)
`ConstFold` should `Walk`
1 parent df59814 commit 139977b

File tree

3 files changed

+12
-30
lines changed

3 files changed

+12
-30
lines changed

src/kirin/ir/dialect.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,7 @@ def __post_init__(self) -> None:
3434
self.lowering["default"] = NoSpecialLowering()
3535

3636
def __repr__(self) -> str:
37-
stmts = ", ".join([stmt.__name__ for stmt in self.stmts])
38-
attrs = ", ".join([attr.__name__ for attr in self.attrs])
39-
interps = ", ".join(
40-
[f"{key} = {type(interp).__name__}" for key, interp in self.interps.items()]
41-
)
42-
lowering = ", ".join(
43-
[f"{key} = {type(lower).__name__}" for key, lower in self.lowering.items()]
44-
)
45-
codegen = ", ".join(
46-
[f"{key} = {type(emit).__name__}" for key, emit in self.codegen.items()]
47-
)
48-
return f"""Dialect(\
49-
name={self.name},\
50-
stmts=[{stmts}], \
51-
attrs=[{attrs}], \
52-
interps=[{interps}], \
53-
lowering=[{lowering}]\
54-
codegen=[{codegen}]\
55-
)"""
37+
return f"Dialect(name={self.name}, ...)"
5638

5739
def __hash__(self) -> int:
5840
return hash(self.name)

src/kirin/rules/aggressive/fold.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def __init__(self, cfg: analysis.CFG, results: dict[ir.SSAValue, Any]):
2222
Chain(
2323
[
2424
Walk(Inline(lambda _: True)),
25-
ConstantFold(results),
26-
Call2Invoke(results),
25+
Walk(ConstantFold(results)),
26+
Walk(Call2Invoke(results)),
2727
Fixpoint(
2828
Walk(
2929
Chain(

test/program/py/test_aggressive.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def main():
3232
return x()
3333

3434

35-
@basic.add(dialect)
36-
def target():
37-
x = DummyStmt2(1, option="hello")
38-
DummyStmt2(x, option="hello")
39-
DummyStmt2(x, option="hello")
40-
return
41-
42-
4335
def test_aggressive_pass():
44-
assert target.callable_region.is_structurally_equal(main.callable_region)
36+
const_count = 0
37+
dummy_count = 0
38+
for stmt in main.callable_region.walk():
39+
if isinstance(stmt, DummyStmt2):
40+
dummy_count += 1
41+
elif stmt.has_trait(ir.ConstantLike):
42+
const_count += 1
43+
assert dummy_count == 3
44+
assert const_count == 2

0 commit comments

Comments
 (0)