Skip to content

Commit 5fd6d12

Browse files
authored
Two Fixes to constant prop, IList (#524)
changes in this PR: 1. Bug fixes in `ilist` dialect that gaurd against non-pure execution of `map` `for_each` `foldl`, etc. ~~2. Adding a try-except inside `try_eval_const_pure` in cases where the concrete interpreter is missing an implementation of a statement that is pure, falling back to returning `Unknown` for all results.~~
1 parent d53e5e6 commit 5fd6d12

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

src/kirin/dialects/ilist/constprop.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def one_args(
3131
# 1. if the function is a constant method, and the method is pure, then the map is pure
3232
if isinstance(fn, const.Value) and isinstance(method := fn.data, ir.Method):
3333
self.detect_purity(interp_, frame, stmt, method.code, (fn, const.Unknown()))
34-
if isinstance(collection, const.Value):
34+
if isinstance(collection, const.Value) and stmt in frame.should_be_pure:
3535
return interp_.try_eval_const_pure(frame, stmt, (fn, collection))
3636
elif isinstance(fn, const.PartialLambda):
3737
self.detect_purity(interp_, frame, stmt, fn.code, (fn, const.Unknown()))
@@ -57,7 +57,11 @@ def two_args(self, interp_: const.Propagate, frame: const.Frame, stmt: Foldl):
5757
method.code,
5858
(fn, const.Unknown(), const.Unknown()),
5959
)
60-
if isinstance(collection, const.Value) and isinstance(init, const.Value):
60+
if (
61+
isinstance(collection, const.Value)
62+
and isinstance(init, const.Value)
63+
and stmt in frame.should_be_pure
64+
):
6165
return interp_.try_eval_const_pure(frame, stmt, (fn, collection, init))
6266
elif isinstance(fn, const.PartialLambda):
6367
self.detect_purity(

test/dialects/test_ilist.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from typing import Any, Literal
22

33
from kirin import ir, types, rewrite
4+
from kirin.decl import info, statement
45
from kirin.passes import aggressive
56
from kirin.prelude import structural, basic_no_opt, python_basic
67
from kirin.analysis import const
78
from kirin.dialects import py, func, ilist, lowering
9+
from kirin.lowering import FromPythonCall
810
from kirin.passes.typeinfer import TypeInfer
911

1012

@@ -426,6 +428,31 @@ def main2():
426428
assert target.data == (6, 6)
427429

428430

431+
def test_ilist_constprop_non_pure():
432+
433+
new_dialect = ir.Dialect("test")
434+
435+
@statement(dialect=new_dialect)
436+
class DefaultInit(ir.Statement):
437+
name = "test"
438+
traits = frozenset({FromPythonCall()})
439+
result: ir.ResultValue = info.result(types.Float)
440+
441+
dialect_group = basic_no_opt.add(new_dialect)
442+
443+
@dialect_group
444+
def test():
445+
446+
def inner(_: int):
447+
return DefaultInit()
448+
449+
return ilist.map(inner, ilist.range(10))
450+
451+
_, res = const.Propagate(dialect_group).run(test)
452+
453+
assert isinstance(res, const.Unknown)
454+
455+
429456
rule = rewrite.Fixpoint(rewrite.Walk(ilist.rewrite.Unroll()))
430457
xs = ilist.IList([1, 2, 3])
431458

0 commit comments

Comments
 (0)