|
| 1 | +import pytest |
| 2 | + |
1 | 3 | from kirin.prelude import basic_no_opt |
2 | 4 | from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst |
3 | 5 | from kirin.analysis import const |
4 | 6 | from kirin.rewrite.dce import DeadCodeElimination |
5 | 7 | from kirin.rewrite.getitem import InlineGetItem |
| 8 | +from kirin.dialects.py.indexing import GetItem |
6 | 9 |
|
7 | 10 |
|
8 | | -@basic_no_opt |
9 | | -def main_simplify_getitem(x: int): |
10 | | - ylist = (x, x, 1, 2) |
11 | | - return ylist[0] |
| 11 | +def apply_getitem_optimization(func): |
| 12 | + constprop = const.Propagate(func.dialects) |
| 13 | + frame, _ = constprop.run(func) |
| 14 | + Fixpoint(Walk(WrapConst(frame))).rewrite(func.code) |
| 15 | + inline_getitem = InlineGetItem() |
| 16 | + print(func.code.print()) |
| 17 | + Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite(func.code) |
12 | 18 |
|
13 | 19 |
|
14 | | -def test_getitem(): |
15 | | - before = main_simplify_getitem(1) |
16 | | - constprop = const.Propagate(main_simplify_getitem.dialects) |
17 | | - frame, _ = constprop.run(main_simplify_getitem) |
18 | | - Fixpoint(Walk(WrapConst(frame))).rewrite(main_simplify_getitem.code) |
19 | | - inline_getitem = InlineGetItem() |
20 | | - Fixpoint(Walk(Chain([inline_getitem, DeadCodeElimination()]))).rewrite( |
21 | | - main_simplify_getitem.code |
22 | | - ) |
23 | | - main_simplify_getitem.code.print() |
24 | | - after = main_simplify_getitem(1) |
| 20 | +@pytest.mark.parametrize("index", [0, -1, 1]) |
| 21 | +def test_getitem_index(index): |
| 22 | + |
| 23 | + @basic_no_opt |
| 24 | + def func(x: int): |
| 25 | + ylist = (x, x, 1, x) |
| 26 | + return ylist[index] |
| 27 | + |
| 28 | + before = func(1) |
| 29 | + apply_getitem_optimization(func) |
| 30 | + after = func(1) |
| 31 | + |
| 32 | + assert before == after |
| 33 | + assert len(func.callable_region.blocks[0].stmts) == 1 |
| 34 | + print(func.code.print()) |
| 35 | + |
| 36 | + |
| 37 | +@pytest.mark.parametrize( |
| 38 | + "sl", |
| 39 | + [ |
| 40 | + slice(0, 2, 1), |
| 41 | + slice(None, None, None), |
| 42 | + slice(None, -1, None), |
| 43 | + slice(-1, None, None), |
| 44 | + slice(None, None, -1), |
| 45 | + slice(1, 4, 2), |
| 46 | + ], |
| 47 | +) |
| 48 | +def test_getitem_slice(sl): |
| 49 | + |
| 50 | + @basic_no_opt |
| 51 | + def func(): |
| 52 | + ylist = (0, 1, 2, 3, 4) |
| 53 | + return ylist[sl] |
| 54 | + |
| 55 | + func.code.print() |
| 56 | + |
| 57 | + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] |
| 58 | + assert GetItem in stmt_types |
| 59 | + |
| 60 | + before = func() |
| 61 | + apply_getitem_optimization(func) |
| 62 | + after = func() |
| 63 | + |
| 64 | + assert before == after |
| 65 | + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] |
| 66 | + assert GetItem not in stmt_types |
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.parametrize( |
| 70 | + "start, stop, step", |
| 71 | + [ |
| 72 | + (0, 2, 1), |
| 73 | + (None, None, None), |
| 74 | + (None, -1, None), |
| 75 | + (-1, None, None), |
| 76 | + (None, None, -1), |
| 77 | + (1, 4, 2), |
| 78 | + ], |
| 79 | +) |
| 80 | +def test_getitem_slice_with_literal_indices(start, stop, step): |
| 81 | + |
| 82 | + @basic_no_opt |
| 83 | + def func(): |
| 84 | + ylist = (0, 1, 2, 3, 4) |
| 85 | + return ylist[start:stop:step] |
| 86 | + |
| 87 | + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] |
| 88 | + assert GetItem in stmt_types |
| 89 | + |
| 90 | + before = func() |
| 91 | + func.code.print() |
| 92 | + |
| 93 | + apply_getitem_optimization(func) |
| 94 | + |
| 95 | + stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts] |
| 96 | + assert GetItem not in stmt_types |
| 97 | + after = func() |
| 98 | + |
| 99 | + func.code.print() |
| 100 | + |
25 | 101 | assert before == after |
26 | | - assert len(main_simplify_getitem.callable_region.blocks[0].stmts) == 1 |
|
0 commit comments