Skip to content

Commit 610d9f8

Browse files
authored
Fix memref (#76)
1 parent c6d16e3 commit 610d9f8

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

mlir/extras/dialects/ext/memref.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __getitem__(self, idx: tuple) -> "MemRef":
108108
if idx is None:
109109
return expand_shape(self, (0,), loc=loc)
110110

111-
idx = list((idx,) if isinstance(idx, int) else idx)
111+
idx = list((idx,) if isinstance(idx, (int, slice)) else idx)
112112
for i, d in enumerate(idx):
113113
if isinstance(d, int):
114114
idx[i] = constant(d, index=True, loc=loc)
@@ -291,6 +291,9 @@ def global_(
291291
sym_name = _get_sym_name(
292292
previous_frame, check_func_call="memref\\.global_|global_"
293293
)
294+
assert (
295+
sym_name is not None
296+
), "couldn't automatically find sym_name in previous frame"
294297
if loc is None:
295298
loc = get_user_code_loc()
296299
if initial_value is None:

tests/test_memref.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,24 @@ def test_simple_literal_indexing(ctx: MLIRContext):
5656
filecheck(correct, ctx.module)
5757

5858

59+
def test_simple_slicing(ctx: MLIRContext):
60+
mem = alloc(10, T.i32())
61+
62+
w = mem[5:]
63+
w = mem[:5]
64+
65+
correct = dedent(
66+
"""\
67+
module {
68+
%alloc = memref.alloc() : memref<10xi32>
69+
%subview = memref.subview %alloc[5] [5] [1] : memref<10xi32> to memref<5xi32, strided<[1], offset: 5>>
70+
%subview_0 = memref.subview %alloc[0] [5] [1] : memref<10xi32> to memref<5xi32>
71+
}
72+
"""
73+
)
74+
filecheck(correct, ctx.module)
75+
76+
5977
def test_simple_literal_indexing_alloca(ctx: MLIRContext):
6078
@alloca_scope([])
6179
def demo_scope2():

0 commit comments

Comments
 (0)