Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/kirin/dialects/ilist/rewrite/inline_getitem.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is almost the same as getitem.py (list vs ilist)

Can code duplication be avoided?

Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
node.result.replace_by(stmt.args[index])
return abc.RewriteResult(has_done_something=True)
elif isinstance(index, slice):
start, stop, step = index.indices(len(stmt.args))
new_tuple = New(
tuple(stmt.args[start:stop:step]),
)
new_tuple = New(tuple(stmt.args[index]))
node.replace_by(new_tuple)
return abc.RewriteResult(has_done_something=True)
else:
Expand Down
3 changes: 1 addition & 2 deletions src/kirin/dialects/py/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ def getitem(
if isinstance(index.data, int) and 0 <= index.data < len(obj):
return (obj[index.data],)
elif isinstance(index.data, slice):
start, stop, step = index.data.indices(len(obj))
return (const.PartialTuple(obj[start:stop:step]),)
return (const.PartialTuple(obj[index.data]),)
return (const.Unknown(),)


Expand Down
5 changes: 1 addition & 4 deletions src/kirin/rewrite/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
node.result.replace_by(stmt.args[index])
return RewriteResult(has_done_something=True)
elif isinstance(index, slice):
start, stop, step = index.indices(len(stmt.args))
new_tuple = py.tuple.New(
tuple(stmt.args[start:stop:step]),
)
new_tuple = py.tuple.New(tuple(stmt.args[index]))
node.replace_by(new_tuple)
return RewriteResult(has_done_something=True)
else:
Expand Down
97 changes: 97 additions & 0 deletions test/dialects/ilist/test_inline_getitem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pytest

from kirin import types
from kirin.prelude import basic_no_opt
from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst
from kirin.analysis import const
from kirin.dialects import ilist
from kirin.rewrite.dce import DeadCodeElimination
from kirin.dialects.py.indexing import GetItem
from kirin.dialects.ilist.rewrite.inline_getitem import InlineGetItem


def apply_getitem_optimization(func):
constprop = const.Propagate(func.dialects)
frame, _ = constprop.run(func)
Fixpoint(Walk(WrapConst(frame))).rewrite(func.code)
inline_getitem = InlineGetItem()
Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code)


@pytest.mark.parametrize("index", [0, -1, 1])
def test_getitem_index(index):
index = 0

@basic_no_opt
def func(x: int):
ylist = ilist.New(values=(x, x, 1, x), elem_type=types.PyClass(int))
return ylist[index]

before = func(1)
apply_getitem_optimization(func)
after = func(1)

assert before == after
assert len(func.callable_region.blocks[0].stmts) == 1


@pytest.mark.parametrize(
"sl",
[
slice(0, 2, 1),
slice(None, None, None),
slice(None, -1, None),
slice(-1, None, None),
slice(None, None, -1),
slice(1, 4, 2),
],
)
def test_getitem_slice(sl):

@basic_no_opt
def func():
ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int))
return ylist[sl]

stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
assert GetItem in stmt_types

before = func()
apply_getitem_optimization(func)
after = func()

assert before == after
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
assert GetItem not in stmt_types


@pytest.mark.parametrize(
"start, stop, step",
[
(0, 2, 1),
(None, None, None),
(None, -1, None),
(-1, None, None),
(None, None, -1),
(1, 4, 2),
],
)
def test_getitem_slice_with_literal_indices(start, stop, step):

@basic_no_opt
def func():
ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int))
return ylist[start:stop:step]

stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
assert GetItem in stmt_types

before = func()

apply_getitem_optimization(func)

stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
assert GetItem not in stmt_types
after = func()

assert before == after
100 changes: 84 additions & 16 deletions test/rules/test_getitem.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,94 @@
import pytest

from kirin.prelude import basic_no_opt
from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst
from kirin.analysis import const
from kirin.rewrite.dce import DeadCodeElimination
from kirin.rewrite.getitem import InlineGetItem
from kirin.dialects.py.indexing import GetItem


@basic_no_opt
def main_simplify_getitem(x: int):
ylist = (x, x, 1, 2)
return ylist[0]
def apply_getitem_optimization(func):
constprop = const.Propagate(func.dialects)
frame, _ = constprop.run(func)
Fixpoint(Walk(WrapConst(frame))).rewrite(func.code)
inline_getitem = InlineGetItem()
Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code)


def test_getitem():
before = main_simplify_getitem(1)
constprop = const.Propagate(main_simplify_getitem.dialects)
frame, _ = constprop.run(main_simplify_getitem)
Fixpoint(Walk(WrapConst(frame))).rewrite(main_simplify_getitem.code)
inline_getitem = InlineGetItem()
Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(
main_simplify_getitem.code
)
main_simplify_getitem.code.print()
after = main_simplify_getitem(1)
@pytest.mark.parametrize("index", [0, -1, 1])
def test_getitem_index(index):

@basic_no_opt
def func(x: int):
ylist = (x, x, 1, x)
return ylist[index]

before = func(1)
apply_getitem_optimization(func)
after = func(1)

assert before == after
assert len(func.callable_region.blocks[0].stmts) == 1


@pytest.mark.parametrize(
"sl",
[
slice(0, 2, 1),
slice(None, None, None),
slice(None, -1, None),
slice(-1, None, None),
slice(None, None, -1),
slice(1, 4, 2),
],
)
def test_getitem_slice(sl):

@basic_no_opt
def func():
ylist = (0, 1, 2, 3, 4)
return ylist[sl]

stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
assert GetItem in stmt_types

before = func()
apply_getitem_optimization(func)
after = func()

assert before == after
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
assert GetItem not in stmt_types


@pytest.mark.parametrize(
"start, stop, step",
[
(0, 2, 1),
(None, None, None),
(None, -1, None),
(-1, None, None),
(None, None, -1),
(1, 4, 2),
],
)
def test_getitem_slice_with_literal_indices(start, stop, step):

@basic_no_opt
def func():
ylist = (0, 1, 2, 3, 4)
return ylist[start:stop:step]

stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
assert GetItem in stmt_types

before = func()

apply_getitem_optimization(func)

stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
assert GetItem not in stmt_types
after = func()

assert before == after
assert len(main_simplify_getitem.callable_region.blocks[0].stmts) == 1