diff --git a/src/kirin/dialects/ilist/rewrite/inline_getitem.py b/src/kirin/dialects/ilist/rewrite/inline_getitem.py index 9120f1648..a84612a44 100644 --- a/src/kirin/dialects/ilist/rewrite/inline_getitem.py +++ b/src/kirin/dialects/ilist/rewrite/inline_getitem.py @@ -32,10 +32,7 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: node.result.replace_by(stmt.args[index]) return abc.RewriteResult(has_done_something=True) elif isinstance(index, slice): - start, stop, step = index.indices(len(stmt.args)) - new_tuple = New( - tuple(stmt.args[start:stop:step]), - ) + new_tuple = New(tuple(stmt.args[index])) node.replace_by(new_tuple) return abc.RewriteResult(has_done_something=True) else: diff --git a/src/kirin/dialects/py/indexing.py b/src/kirin/dialects/py/indexing.py index a1d88243b..fab1ac038 100644 --- a/src/kirin/dialects/py/indexing.py +++ b/src/kirin/dialects/py/indexing.py @@ -214,8 +214,7 @@ def getitem( if isinstance(index.data, int) and 0 <= index.data < len(obj): return (obj[index.data],) elif isinstance(index.data, slice): - start, stop, step = index.data.indices(len(obj)) - return (const.PartialTuple(obj[start:stop:step]),) + return (const.PartialTuple(obj[index.data]),) return (const.Unknown(),) diff --git a/src/kirin/rewrite/getitem.py b/src/kirin/rewrite/getitem.py index 5cc91d024..47c10c130 100644 --- a/src/kirin/rewrite/getitem.py +++ b/src/kirin/rewrite/getitem.py @@ -27,10 +27,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: node.result.replace_by(stmt.args[index]) return RewriteResult(has_done_something=True) elif isinstance(index, slice): - start, stop, step = index.indices(len(stmt.args)) - new_tuple = py.tuple.New( - tuple(stmt.args[start:stop:step]), - ) + new_tuple = py.tuple.New(tuple(stmt.args[index])) node.replace_by(new_tuple) return RewriteResult(has_done_something=True) else: diff --git a/test/dialects/ilist/test_inline_getitem.py b/test/dialects/ilist/test_inline_getitem.py new file mode 100644 index 000000000..119a87e82 --- /dev/null +++ b/test/dialects/ilist/test_inline_getitem.py @@ -0,0 +1,97 @@ +import pytest + +from kirin import types +from kirin.prelude import basic_no_opt +from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst +from kirin.analysis import const +from kirin.dialects import ilist +from kirin.rewrite.dce import DeadCodeElimination +from kirin.dialects.py.indexing import GetItem +from kirin.dialects.ilist.rewrite.inline_getitem import InlineGetItem + + +def apply_getitem_optimization(func): + constprop = const.Propagate(func.dialects) + frame, _ = constprop.run(func) + Fixpoint(Walk(WrapConst(frame))).rewrite(func.code) + inline_getitem = InlineGetItem() + Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code) + + +@pytest.mark.parametrize("index", [0, -1, 1]) +def test_getitem_index(index): + index = 0 + + @basic_no_opt + def func(x: int): + ylist = ilist.New(values=(x, x, 1, x), elem_type=types.PyClass(int)) + return ylist[index] + + before = func(1) + apply_getitem_optimization(func) + after = func(1) + + assert before == after + assert len(func.callable_region.blocks[0].stmts) == 1 + + +@pytest.mark.parametrize( + "sl", + [ + slice(0, 2, 1), + slice(None, None, None), + slice(None, -1, None), + slice(-1, None, None), + slice(None, None, -1), + slice(1, 4, 2), + ], +) +def test_getitem_slice(sl): + + @basic_no_opt + def func(): + ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int)) + return ylist[sl] + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem in stmt_types + + before = func() + apply_getitem_optimization(func) + after = func() + + assert before == after + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem not in stmt_types + + +@pytest.mark.parametrize( + "start, stop, step", + [ + (0, 2, 1), + (None, None, None), + (None, -1, None), + (-1, None, None), + (None, None, -1), + (1, 4, 2), + ], +) +def test_getitem_slice_with_literal_indices(start, stop, step): + + @basic_no_opt + def func(): + ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int)) + return ylist[start:stop:step] + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem in stmt_types + + before = func() + + apply_getitem_optimization(func) + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem not in stmt_types + after = func() + + assert before == after diff --git a/test/rules/test_getitem.py b/test/rules/test_getitem.py index b4a943c76..76e7633de 100644 --- a/test/rules/test_getitem.py +++ b/test/rules/test_getitem.py @@ -1,26 +1,94 @@ +import pytest + from kirin.prelude import basic_no_opt from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst from kirin.analysis import const from kirin.rewrite.dce import DeadCodeElimination from kirin.rewrite.getitem import InlineGetItem +from kirin.dialects.py.indexing import GetItem -@basic_no_opt -def main_simplify_getitem(x: int): - ylist = (x, x, 1, 2) - return ylist[0] +def apply_getitem_optimization(func): + constprop = const.Propagate(func.dialects) + frame, _ = constprop.run(func) + Fixpoint(Walk(WrapConst(frame))).rewrite(func.code) + inline_getitem = InlineGetItem() + Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code) -def test_getitem(): - before = main_simplify_getitem(1) - constprop = const.Propagate(main_simplify_getitem.dialects) - frame, _ = constprop.run(main_simplify_getitem) - Fixpoint(Walk(WrapConst(frame))).rewrite(main_simplify_getitem.code) - inline_getitem = InlineGetItem() - Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite( - main_simplify_getitem.code - ) - main_simplify_getitem.code.print() - after = main_simplify_getitem(1) +@pytest.mark.parametrize("index", [0, -1, 1]) +def test_getitem_index(index): + + @basic_no_opt + def func(x: int): + ylist = (x, x, 1, x) + return ylist[index] + + before = func(1) + apply_getitem_optimization(func) + after = func(1) + + assert before == after + assert len(func.callable_region.blocks[0].stmts) == 1 + + +@pytest.mark.parametrize( + "sl", + [ + slice(0, 2, 1), + slice(None, None, None), + slice(None, -1, None), + slice(-1, None, None), + slice(None, None, -1), + slice(1, 4, 2), + ], +) +def test_getitem_slice(sl): + + @basic_no_opt + def func(): + ylist = (0, 1, 2, 3, 4) + return ylist[sl] + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem in stmt_types + + before = func() + apply_getitem_optimization(func) + after = func() + + assert before == after + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem not in stmt_types + + +@pytest.mark.parametrize( + "start, stop, step", + [ + (0, 2, 1), + (None, None, None), + (None, -1, None), + (-1, None, None), + (None, None, -1), + (1, 4, 2), + ], +) +def test_getitem_slice_with_literal_indices(start, stop, step): + + @basic_no_opt + def func(): + ylist = (0, 1, 2, 3, 4) + return ylist[start:stop:step] + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem in stmt_types + + before = func() + + apply_getitem_optimization(func) + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem not in stmt_types + after = func() + assert before == after - assert len(main_simplify_getitem.callable_region.blocks[0].stmts) == 1