Skip to content

Commit 9f333b2

Browse files
authored
Adding inline getitem rewrite for ilist (#446)
This is currently missing, it is definitely nice to have.
1 parent 59c8d8b commit 9f333b2

File tree

3 files changed

+79
-0
lines changed

3 files changed

+79
-0
lines changed

src/kirin/dialects/ilist/rewrite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .const import ConstList2IList as ConstList2IList
33
from .unroll import Unroll as Unroll
44
from .hint_len import HintLen as HintLen
5+
from .inline_getitem import InlineGetItem as InlineGetItem
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from kirin import ir
2+
from kirin.rewrite import abc
3+
from kirin.analysis import const
4+
from kirin.dialects import py
5+
6+
from ..stmts import New
7+
8+
9+
class InlineGetItem(abc.RewriteRule):
10+
"""Rewrite rule to inline GetItem statements for IList.
11+
12+
For example if we have an `ilist.New` statement with a list of items,
13+
and we can infer that the index used in `py.GetItem` is constant and within bounds,
14+
we replace the `py.GetItem` with the ssa value in the list when the index is an integer
15+
or with a new `ilist.New` statement containing the sliced items when the index is a slice.
16+
17+
"""
18+
19+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
20+
if not isinstance(node, py.GetItem) or not isinstance(
21+
stmt := node.obj.owner, New
22+
):
23+
return abc.RewriteResult()
24+
25+
if not isinstance(index_const := node.index.hints.get("const"), const.Value):
26+
return abc.RewriteResult()
27+
28+
index = index_const.data
29+
if isinstance(index, int) and (
30+
0 <= index < len(stmt.args) or -len(stmt.args) <= index < 0
31+
):
32+
node.result.replace_by(stmt.args[index])
33+
return abc.RewriteResult(has_done_something=True)
34+
elif isinstance(index, slice):
35+
start, stop, step = index.indices(len(stmt.args))
36+
new_tuple = New(
37+
tuple(stmt.args[start:stop:step]),
38+
)
39+
node.replace_by(new_tuple)
40+
return abc.RewriteResult(has_done_something=True)
41+
else:
42+
return abc.RewriteResult()

test/dialects/test_ilist.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from kirin import ir, types, rewrite
44
from kirin.passes import aggressive
55
from kirin.prelude import basic_no_opt, python_basic
6+
from kirin.analysis import const
67
from kirin.dialects import py, func, ilist, lowering
78
from kirin.passes.typeinfer import TypeInfer
89

@@ -163,6 +164,41 @@ def const_range():
163164
assert const_range() == ilist.IList(range(0, 3))
164165

165166

167+
def test_inline_get_item():
168+
items = tuple(ir.TestValue() for _ in range(2))
169+
170+
test_block = ir.Block(
171+
[
172+
qreg := ilist.New(values=items),
173+
idx := py.Constant(0),
174+
qubit_stmt := py.GetItem(obj=qreg.result, index=idx.result),
175+
ilist.New(values=(qubit_stmt.result,)),
176+
idx1 := py.Constant(10),
177+
qubit_stmt := py.GetItem(obj=qreg.result, index=idx1.result),
178+
ilist.New(values=(qubit_stmt.result,)),
179+
]
180+
)
181+
182+
idx.result.hints["const"] = const.Value(0)
183+
idx1.result.hints["const"] = const.Value(10)
184+
rule = rewrite.Walk(ilist.rewrite.InlineGetItem())
185+
rule.rewrite(test_block)
186+
187+
expected_block = ir.Block(
188+
[
189+
qreg := ilist.New(values=items),
190+
idx := py.Constant(0),
191+
qubit_stmt := py.GetItem(obj=qreg.result, index=idx.result),
192+
ilist.New(values=(items[0],)),
193+
idx1 := py.Constant(10),
194+
qubit_stmt := py.GetItem(obj=qreg.result, index=idx1.result),
195+
ilist.New(values=(qubit_stmt.result,)),
196+
]
197+
)
198+
199+
assert test_block.is_equal(expected_block)
200+
201+
166202
def test_ilist_constprop():
167203
from kirin.analysis import const
168204

0 commit comments

Comments
 (0)