Skip to content

Commit eb0bbca

Browse files
authored
add memref.view (#79)
1 parent 610d9f8 commit eb0bbca

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/extras/dialects/ext/memref.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,12 @@ def global_(
318318
loc=loc,
319319
ip=ip,
320320
).opview
321+
322+
323+
def view(source, shape, dtype=None, shift=0):
324+
if dtype is None:
325+
dtype = source.type.element_type
326+
byte_width_dtype = dtype.width // 8
327+
byte_shift = shift * byte_width_dtype
328+
byte_shift = constant(byte_shift, index=True)
329+
return memref.view(T.memref(*shape, dtype), source, byte_shift, [])

tests/test_memref.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,25 @@ def test_memref_global_non_windows(ctx: MLIRContext):
661661
)
662662

663663
filecheck(correct, ctx.module)
664+
665+
666+
def test_memref_view(ctx: MLIRContext):
667+
m, k, n = 16, 16, 16
668+
dtype = T.f32()
669+
byte_width_dtype = dtype.width // 8
670+
ab_buffer = alloc((m * k + k * n) * byte_width_dtype, T.i8())
671+
a_buffer = memref.view(ab_buffer, (m, k), dtype=dtype)
672+
b_buffer = memref.view(ab_buffer, (k, n), dtype=dtype, shift=m * k)
673+
674+
correct = dedent(
675+
"""\
676+
module {
677+
%alloc = memref.alloc() : memref<2048xi8>
678+
%c0 = arith.constant 0 : index
679+
%view = memref.view %alloc[%c0][] : memref<2048xi8> to memref<16x16xf32>
680+
%c1024 = arith.constant 1024 : index
681+
%view_0 = memref.view %alloc[%c1024][] : memref<2048xi8> to memref<16x16xf32>
682+
}
683+
"""
684+
)
685+
filecheck(correct, ctx.module)

0 commit comments

Comments
 (0)