From d9ad5f544901634ca313b884e9d8821c3083fdb8 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Fri, 31 Oct 2025 11:41:32 -0700 Subject: [PATCH 1/3] Fix bug in edge case with reverse slices like `my_list[::-1]` This is coming from unexpected behavior of `Slice.indices`. The solution is to simply avoid using `.indices`. A MRE is nums = [0, 1, 2, 3, 4] sl = slice(None, None, -1) start, stop, step = sl.indices(len(nums)) print(nums[sl]) # [4, 3, 2, 1, 0] print(nums[start:stop:step]) # [] --- src/kirin/dialects/ilist/rewrite/inline_getitem.py | 5 +---- src/kirin/dialects/py/indexing.py | 3 +-- src/kirin/rewrite/getitem.py | 5 +---- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/kirin/dialects/ilist/rewrite/inline_getitem.py b/src/kirin/dialects/ilist/rewrite/inline_getitem.py index 9120f1648..a84612a44 100644 --- a/src/kirin/dialects/ilist/rewrite/inline_getitem.py +++ b/src/kirin/dialects/ilist/rewrite/inline_getitem.py @@ -32,10 +32,7 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: node.result.replace_by(stmt.args[index]) return abc.RewriteResult(has_done_something=True) elif isinstance(index, slice): - start, stop, step = index.indices(len(stmt.args)) - new_tuple = New( - tuple(stmt.args[start:stop:step]), - ) + new_tuple = New(tuple(stmt.args[index])) node.replace_by(new_tuple) return abc.RewriteResult(has_done_something=True) else: diff --git a/src/kirin/dialects/py/indexing.py b/src/kirin/dialects/py/indexing.py index a1d88243b..fab1ac038 100644 --- a/src/kirin/dialects/py/indexing.py +++ b/src/kirin/dialects/py/indexing.py @@ -214,8 +214,7 @@ def getitem( if isinstance(index.data, int) and 0 <= index.data < len(obj): return (obj[index.data],) elif isinstance(index.data, slice): - start, stop, step = index.data.indices(len(obj)) - return (const.PartialTuple(obj[start:stop:step]),) + return (const.PartialTuple(obj[index.data]),) return (const.Unknown(),) diff --git a/src/kirin/rewrite/getitem.py b/src/kirin/rewrite/getitem.py index 5cc91d024..47c10c130 100644 --- a/src/kirin/rewrite/getitem.py +++ b/src/kirin/rewrite/getitem.py @@ -27,10 +27,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: node.result.replace_by(stmt.args[index]) return RewriteResult(has_done_something=True) elif isinstance(index, slice): - start, stop, step = index.indices(len(stmt.args)) - new_tuple = py.tuple.New( - tuple(stmt.args[start:stop:step]), - ) + new_tuple = py.tuple.New(tuple(stmt.args[index])) node.replace_by(new_tuple) return RewriteResult(has_done_something=True) else: From a1e276d233b87852ccbc4c3054dcc49c496ea7d8 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Fri, 31 Oct 2025 11:41:57 -0700 Subject: [PATCH 2/3] Add additional unit tests --- test/dialects/ilist/test_inline_getitem.py | 103 ++++++++++++++++++++ test/rules/test_getitem.py | 107 ++++++++++++++++++--- 2 files changed, 194 insertions(+), 16 deletions(-) create mode 100644 test/dialects/ilist/test_inline_getitem.py diff --git a/test/dialects/ilist/test_inline_getitem.py b/test/dialects/ilist/test_inline_getitem.py new file mode 100644 index 000000000..827ad9c2b --- /dev/null +++ b/test/dialects/ilist/test_inline_getitem.py @@ -0,0 +1,103 @@ +import pytest + +from kirin import types +from kirin.prelude import basic_no_opt +from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst +from kirin.analysis import const +from kirin.dialects import ilist +from kirin.rewrite.dce import DeadCodeElimination +from kirin.dialects.py.indexing import GetItem +from kirin.dialects.ilist.rewrite.inline_getitem import InlineGetItem + + +def apply_getitem_optimization(func): + constprop = const.Propagate(func.dialects) + frame, _ = constprop.run(func) + Fixpoint(Walk(WrapConst(frame))).rewrite(func.code) + inline_getitem = InlineGetItem() + Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code) + + +@pytest.mark.parametrize("index", [0, -1, 1]) +def test_getitem_index(index): + index = 0 + + @basic_no_opt + def func(x: int): + ylist = ilist.New(values=(x, x, 1, x), elem_type=types.PyClass(int)) + return ylist[index] + + before = func(1) + apply_getitem_optimization(func) + after = func(1) + + assert before == after + assert len(func.callable_region.blocks[0].stmts) == 1 + print(func.code.print()) + + +@pytest.mark.parametrize( + "sl", + [ + slice(0, 2, 1), + slice(None, None, None), + slice(None, -1, None), + slice(-1, None, None), + slice(None, None, -1), + slice(1, 4, 2), + ], +) +def test_getitem_slice(sl): + + @basic_no_opt + def func(): + ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int)) + return ylist[sl] + + func.code.print() + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem in stmt_types + + before = func() + apply_getitem_optimization(func) + after = func() + + assert before == after + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem not in stmt_types + + +@pytest.mark.parametrize( + "start, stop, step", + [ + (0, 2, 1), + (None, None, None), + (None, -1, None), + (-1, None, None), + (None, None, -1), + (1, 4, 2), + ], +) +def test_getitem_slice_with_literal_indices(start, stop, step): + + @basic_no_opt + def func(): + ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int)) + return ylist[start:stop:step] + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem in stmt_types + + before = func() + func.code.print() + + apply_getitem_optimization(func) + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem not in stmt_types + after = func() + + func.code.print() + + assert before == after diff --git a/test/rules/test_getitem.py b/test/rules/test_getitem.py index b4a943c76..7c27ad0f1 100644 --- a/test/rules/test_getitem.py +++ b/test/rules/test_getitem.py @@ -1,26 +1,101 @@ +import pytest + from kirin.prelude import basic_no_opt from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst from kirin.analysis import const from kirin.rewrite.dce import DeadCodeElimination from kirin.rewrite.getitem import InlineGetItem +from kirin.dialects.py.indexing import GetItem -@basic_no_opt -def main_simplify_getitem(x: int): - ylist = (x, x, 1, 2) - return ylist[0] +def apply_getitem_optimization(func): + constprop = const.Propagate(func.dialects) + frame, _ = constprop.run(func) + Fixpoint(Walk(WrapConst(frame))).rewrite(func.code) + inline_getitem = InlineGetItem() + print(func.code.print()) + Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code) -def test_getitem(): - before = main_simplify_getitem(1) - constprop = const.Propagate(main_simplify_getitem.dialects) - frame, _ = constprop.run(main_simplify_getitem) - Fixpoint(Walk(WrapConst(frame))).rewrite(main_simplify_getitem.code) - inline_getitem = InlineGetItem() - Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite( - main_simplify_getitem.code - ) - main_simplify_getitem.code.print() - after = main_simplify_getitem(1) +@pytest.mark.parametrize("index", [0, -1, 1]) +def test_getitem_index(index): + + @basic_no_opt + def func(x: int): + ylist = (x, x, 1, x) + return ylist[index] + + before = func(1) + apply_getitem_optimization(func) + after = func(1) + + assert before == after + assert len(func.callable_region.blocks[0].stmts) == 1 + print(func.code.print()) + + +@pytest.mark.parametrize( + "sl", + [ + slice(0, 2, 1), + slice(None, None, None), + slice(None, -1, None), + slice(-1, None, None), + slice(None, None, -1), + slice(1, 4, 2), + ], +) +def test_getitem_slice(sl): + + @basic_no_opt + def func(): + ylist = (0, 1, 2, 3, 4) + return ylist[sl] + + func.code.print() + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem in stmt_types + + before = func() + apply_getitem_optimization(func) + after = func() + + assert before == after + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem not in stmt_types + + +@pytest.mark.parametrize( + "start, stop, step", + [ + (0, 2, 1), + (None, None, None), + (None, -1, None), + (-1, None, None), + (None, None, -1), + (1, 4, 2), + ], +) +def test_getitem_slice_with_literal_indices(start, stop, step): + + @basic_no_opt + def func(): + ylist = (0, 1, 2, 3, 4) + return ylist[start:stop:step] + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem in stmt_types + + before = func() + func.code.print() + + apply_getitem_optimization(func) + + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] + assert GetItem not in stmt_types + after = func() + + func.code.print() + assert before == after - assert len(main_simplify_getitem.callable_region.blocks[0].stmts) == 1 From f61f805541d1c07df727e9364ccd0b5b12432d80 Mon Sep 17 00:00:00 2001 From: Rafael Haenel Date: Fri, 31 Oct 2025 11:46:38 -0700 Subject: [PATCH 3/3] Remove leftover print statements --- test/dialects/ilist/test_inline_getitem.py | 6 ------ test/rules/test_getitem.py | 7 ------- 2 files changed, 13 deletions(-) diff --git a/test/dialects/ilist/test_inline_getitem.py b/test/dialects/ilist/test_inline_getitem.py index 827ad9c2b..119a87e82 100644 --- a/test/dialects/ilist/test_inline_getitem.py +++ b/test/dialects/ilist/test_inline_getitem.py @@ -33,7 +33,6 @@ def func(x: int): assert before == after assert len(func.callable_region.blocks[0].stmts) == 1 - print(func.code.print()) @pytest.mark.parametrize( @@ -54,8 +53,6 @@ def func(): ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int)) return ylist[sl] - func.code.print() - stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] assert GetItem in stmt_types @@ -90,7 +87,6 @@ def func(): assert GetItem in stmt_types before = func() - func.code.print() apply_getitem_optimization(func) @@ -98,6 +94,4 @@ def func(): assert GetItem not in stmt_types after = func() - func.code.print() - assert before == after diff --git a/test/rules/test_getitem.py b/test/rules/test_getitem.py index 7c27ad0f1..76e7633de 100644 --- a/test/rules/test_getitem.py +++ b/test/rules/test_getitem.py @@ -13,7 +13,6 @@ def apply_getitem_optimization(func): frame, _ = constprop.run(func) Fixpoint(Walk(WrapConst(frame))).rewrite(func.code) inline_getitem = InlineGetItem() - print(func.code.print()) Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code) @@ -31,7 +30,6 @@ def func(x: int): assert before == after assert len(func.callable_region.blocks[0].stmts) == 1 - print(func.code.print()) @pytest.mark.parametrize( @@ -52,8 +50,6 @@ def func(): ylist = (0, 1, 2, 3, 4) return ylist[sl] - func.code.print() - stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] assert GetItem in stmt_types @@ -88,7 +84,6 @@ def func(): assert GetItem in stmt_types before = func() - func.code.print() apply_getitem_optimization(func) @@ -96,6 +91,4 @@ def func(): assert GetItem not in stmt_types after = func() - func.code.print() - assert before == after