Skip to content

Commit a1e276d

Browse files
committed
Add additional unit tests
1 parent d9ad5f5 commit a1e276d

File tree

2 files changed

+194
-16
lines changed

2 files changed

+194
-16
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import pytest
2+
3+
from kirin import types
4+
from kirin.prelude import basic_no_opt
5+
from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst
6+
from kirin.analysis import const
7+
from kirin.dialects import ilist
8+
from kirin.rewrite.dce import DeadCodeElimination
9+
from kirin.dialects.py.indexing import GetItem
10+
from kirin.dialects.ilist.rewrite.inline_getitem import InlineGetItem
11+
12+
13+
def apply_getitem_optimization(func):
14+
constprop = const.Propagate(func.dialects)
15+
frame, _ = constprop.run(func)
16+
Fixpoint(Walk(WrapConst(frame))).rewrite(func.code)
17+
inline_getitem = InlineGetItem()
18+
Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code)
19+
20+
21+
@pytest.mark.parametrize("index", [0, -1, 1])
22+
def test_getitem_index(index):
23+
index = 0
24+
25+
@basic_no_opt
26+
def func(x: int):
27+
ylist = ilist.New(values=(x, x, 1, x), elem_type=types.PyClass(int))
28+
return ylist[index]
29+
30+
before = func(1)
31+
apply_getitem_optimization(func)
32+
after = func(1)
33+
34+
assert before == after
35+
assert len(func.callable_region.blocks[0].stmts) == 1
36+
print(func.code.print())
37+
38+
39+
@pytest.mark.parametrize(
40+
"sl",
41+
[
42+
slice(0, 2, 1),
43+
slice(None, None, None),
44+
slice(None, -1, None),
45+
slice(-1, None, None),
46+
slice(None, None, -1),
47+
slice(1, 4, 2),
48+
],
49+
)
50+
def test_getitem_slice(sl):
51+
52+
@basic_no_opt
53+
def func():
54+
ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int))
55+
return ylist[sl]
56+
57+
func.code.print()
58+
59+
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
60+
assert GetItem in stmt_types
61+
62+
before = func()
63+
apply_getitem_optimization(func)
64+
after = func()
65+
66+
assert before == after
67+
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
68+
assert GetItem not in stmt_types
69+
70+
71+
@pytest.mark.parametrize(
72+
"start, stop, step",
73+
[
74+
(0, 2, 1),
75+
(None, None, None),
76+
(None, -1, None),
77+
(-1, None, None),
78+
(None, None, -1),
79+
(1, 4, 2),
80+
],
81+
)
82+
def test_getitem_slice_with_literal_indices(start, stop, step):
83+
84+
@basic_no_opt
85+
def func():
86+
ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int))
87+
return ylist[start:stop:step]
88+
89+
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
90+
assert GetItem in stmt_types
91+
92+
before = func()
93+
func.code.print()
94+
95+
apply_getitem_optimization(func)
96+
97+
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
98+
assert GetItem not in stmt_types
99+
after = func()
100+
101+
func.code.print()
102+
103+
assert before == after

test/rules/test_getitem.py

Lines changed: 91 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,101 @@
1+
import pytest
2+
13
from kirin.prelude import basic_no_opt
24
from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst
35
from kirin.analysis import const
46
from kirin.rewrite.dce import DeadCodeElimination
57
from kirin.rewrite.getitem import InlineGetItem
8+
from kirin.dialects.py.indexing import GetItem
69

710

8-
@basic_no_opt
9-
def main_simplify_getitem(x: int):
10-
ylist = (x, x, 1, 2)
11-
return ylist[0]
11+
def apply_getitem_optimization(func):
12+
constprop = const.Propagate(func.dialects)
13+
frame, _ = constprop.run(func)
14+
Fixpoint(Walk(WrapConst(frame))).rewrite(func.code)
15+
inline_getitem = InlineGetItem()
16+
print(func.code.print())
17+
Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code)
1218

1319

14-
def test_getitem():
15-
before = main_simplify_getitem(1)
16-
constprop = const.Propagate(main_simplify_getitem.dialects)
17-
frame, _ = constprop.run(main_simplify_getitem)
18-
Fixpoint(Walk(WrapConst(frame))).rewrite(main_simplify_getitem.code)
19-
inline_getitem = InlineGetItem()
20-
Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(
21-
main_simplify_getitem.code
22-
)
23-
main_simplify_getitem.code.print()
24-
after = main_simplify_getitem(1)
20+
@pytest.mark.parametrize("index", [0, -1, 1])
21+
def test_getitem_index(index):
22+
23+
@basic_no_opt
24+
def func(x: int):
25+
ylist = (x, x, 1, x)
26+
return ylist[index]
27+
28+
before = func(1)
29+
apply_getitem_optimization(func)
30+
after = func(1)
31+
32+
assert before == after
33+
assert len(func.callable_region.blocks[0].stmts) == 1
34+
print(func.code.print())
35+
36+
37+
@pytest.mark.parametrize(
38+
"sl",
39+
[
40+
slice(0, 2, 1),
41+
slice(None, None, None),
42+
slice(None, -1, None),
43+
slice(-1, None, None),
44+
slice(None, None, -1),
45+
slice(1, 4, 2),
46+
],
47+
)
48+
def test_getitem_slice(sl):
49+
50+
@basic_no_opt
51+
def func():
52+
ylist = (0, 1, 2, 3, 4)
53+
return ylist[sl]
54+
55+
func.code.print()
56+
57+
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
58+
assert GetItem in stmt_types
59+
60+
before = func()
61+
apply_getitem_optimization(func)
62+
after = func()
63+
64+
assert before == after
65+
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
66+
assert GetItem not in stmt_types
67+
68+
69+
@pytest.mark.parametrize(
70+
"start, stop, step",
71+
[
72+
(0, 2, 1),
73+
(None, None, None),
74+
(None, -1, None),
75+
(-1, None, None),
76+
(None, None, -1),
77+
(1, 4, 2),
78+
],
79+
)
80+
def test_getitem_slice_with_literal_indices(start, stop, step):
81+
82+
@basic_no_opt
83+
def func():
84+
ylist = (0, 1, 2, 3, 4)
85+
return ylist[start:stop:step]
86+
87+
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
88+
assert GetItem in stmt_types
89+
90+
before = func()
91+
func.code.print()
92+
93+
apply_getitem_optimization(func)
94+
95+
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
96+
assert GetItem not in stmt_types
97+
after = func()
98+
99+
func.code.print()
100+
25101
assert before == after
26-
assert len(main_simplify_getitem.callable_region.blocks[0].stmts) == 1

0 commit comments

Comments
 (0)