Skip to content

Commit ba622c3

Browse files
authored
add output_shape to memref/tensor (#89)
1 parent 74e2bab commit ba622c3

File tree

4 files changed

+31
-23
lines changed

4 files changed

+31
-23
lines changed

mlir/extras/dialects/ext/memref.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,13 @@ def expand_shape(
210210

211211
return MemRef(
212212
memref.expand_shape(
213-
T.memref(*result_shape, inp.dtype), inp, reassoc_list, loc=loc, ip=ip
213+
T.memref(*result_shape, inp.dtype),
214+
inp,
215+
reassoc_list,
216+
output_shape=[],
217+
static_output_shape=result_shape,
218+
loc=loc,
219+
ip=ip,
214220
)
215221
)
216222

mlir/extras/dialects/ext/tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ def expand_dims(
325325
RankedTensorType.get(result_shape, inp.dtype),
326326
inp,
327327
reassoc_list,
328+
output_shape=[],
329+
static_output_shape=result_shape,
328330
loc=loc,
329331
ip=ip,
330332
)

tests/test_memref.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -343,27 +343,27 @@ def test_none_indices(ctx: MLIRContext):
343343
"""\
344344
module {
345345
%alloc = memref.alloc() : memref<10x22x333x4444xi32>
346-
%expand_shape = memref.expand_shape %alloc [[0, 1], [2], [3], [4]] : memref<10x22x333x4444xi32> into memref<1x10x22x333x4444xi32>
346+
%expand_shape = memref.expand_shape %alloc [[0, 1], [2], [3], [4]] output_shape [1, 10, 22, 333, 4444] : memref<10x22x333x4444xi32> into memref<1x10x22x333x4444xi32>
347347
%subview = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
348-
%expand_shape_0 = memref.expand_shape %subview [[0, 1], [2], [3], [4]] : memref<10x22x333x4444xi32> into memref<10x1x22x333x4444xi32>
348+
%expand_shape_0 = memref.expand_shape %subview [[0, 1], [2], [3], [4]] output_shape [10, 1, 22, 333, 4444] : memref<10x22x333x4444xi32> into memref<10x1x22x333x4444xi32>
349349
%subview_1 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
350-
%expand_shape_2 = memref.expand_shape %subview_1 [[0, 1, 2], [3], [4], [5]] : memref<10x22x333x4444xi32> into memref<1x10x1x22x333x4444xi32>
350+
%expand_shape_2 = memref.expand_shape %subview_1 [[0, 1, 2], [3], [4], [5]] output_shape [1, 10, 1, 22, 333, 4444] : memref<10x22x333x4444xi32> into memref<1x10x1x22x333x4444xi32>
351351
%subview_3 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
352-
%expand_shape_4 = memref.expand_shape %subview_3 [[0], [1, 2], [3], [4]] : memref<10x22x333x4444xi32> into memref<10x22x1x333x4444xi32>
352+
%expand_shape_4 = memref.expand_shape %subview_3 [[0], [1, 2], [3], [4]] output_shape [10, 22, 1, 333, 4444] : memref<10x22x333x4444xi32> into memref<10x22x1x333x4444xi32>
353353
%subview_5 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
354-
%expand_shape_6 = memref.expand_shape %subview_5 [[0], [1], [2, 3], [4]] : memref<10x22x333x4444xi32> into memref<10x22x333x1x4444xi32>
354+
%expand_shape_6 = memref.expand_shape %subview_5 [[0], [1], [2, 3], [4]] output_shape [10, 22, 333, 1, 4444] : memref<10x22x333x4444xi32> into memref<10x22x333x1x4444xi32>
355355
%subview_7 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
356-
%expand_shape_8 = memref.expand_shape %subview_7 [[0], [1], [2], [3, 4]] : memref<10x22x333x4444xi32> into memref<10x22x333x4444x1xi32>
356+
%expand_shape_8 = memref.expand_shape %subview_7 [[0], [1], [2], [3, 4]] output_shape [10, 22, 333, 4444, 1] : memref<10x22x333x4444xi32> into memref<10x22x333x4444x1xi32>
357357
%subview_9 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
358-
%expand_shape_10 = memref.expand_shape %subview_9 [[0], [1], [2], [3, 4]] : memref<10x22x333x4444xi32> into memref<10x22x333x4444x1xi32>
358+
%expand_shape_10 = memref.expand_shape %subview_9 [[0], [1], [2], [3, 4]] output_shape [10, 22, 333, 4444, 1] : memref<10x22x333x4444xi32> into memref<10x22x333x4444x1xi32>
359359
%subview_11 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
360-
%expand_shape_12 = memref.expand_shape %subview_11 [[0, 1], [2], [3], [4, 5]] : memref<10x22x333x4444xi32> into memref<10x1x22x333x4444x1xi32>
360+
%expand_shape_12 = memref.expand_shape %subview_11 [[0, 1], [2], [3], [4, 5]] output_shape [10, 1, 22, 333, 4444, 1] : memref<10x22x333x4444xi32> into memref<10x1x22x333x4444x1xi32>
361361
%subview_13 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
362-
%expand_shape_14 = memref.expand_shape %subview_13 [[0, 1], [2, 3], [4], [5, 6]] : memref<10x22x333x4444xi32> into memref<10x1x22x1x333x4444x1xi32>
362+
%expand_shape_14 = memref.expand_shape %subview_13 [[0, 1], [2, 3], [4], [5, 6]] output_shape [10, 1, 22, 1, 333, 4444, 1] : memref<10x22x333x4444xi32> into memref<10x1x22x1x333x4444x1xi32>
363363
%subview_15 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
364-
%expand_shape_16 = memref.expand_shape %subview_15 [[0, 1], [2, 3], [4, 5], [6, 7]] : memref<10x22x333x4444xi32> into memref<10x1x22x1x333x1x4444x1xi32>
364+
%expand_shape_16 = memref.expand_shape %subview_15 [[0, 1], [2, 3], [4, 5], [6, 7]] output_shape [10, 1, 22, 1, 333, 1, 4444, 1] : memref<10x22x333x4444xi32> into memref<10x1x22x1x333x1x4444x1xi32>
365365
%subview_17 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
366-
%expand_shape_18 = memref.expand_shape %subview_17 [[0, 1, 2], [3, 4], [5, 6], [7, 8]] : memref<10x22x333x4444xi32> into memref<1x10x1x22x1x333x1x4444x1xi32>
366+
%expand_shape_18 = memref.expand_shape %subview_17 [[0, 1, 2], [3, 4], [5, 6], [7, 8]] output_shape [1, 10, 1, 22, 1, 333, 1, 4444, 1] : memref<10x22x333x4444xi32> into memref<1x10x1x22x1x333x1x4444x1xi32>
367367
%subview_19 = memref.subview %alloc[0, 0, 0, 0] [10, 22, 333, 4444] [1, 1, 1, 1] : memref<10x22x333x4444xi32> to memref<10x22x333x4444xi32>
368368
}
369369
"""

tests/test_tensor.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,17 +223,17 @@ def test_none_indices(ctx: MLIRContext):
223223
"""\
224224
module {
225225
%0 = tensor.empty() : tensor<10x22x333x4444xi32>
226-
%expanded = tensor.expand_shape %0 [[0, 1], [2], [3], [4]] : tensor<10x22x333x4444xi32> into tensor<1x10x22x333x4444xi32>
227-
%expanded_0 = tensor.expand_shape %0 [[0, 1], [2], [3], [4]] : tensor<10x22x333x4444xi32> into tensor<10x1x22x333x4444xi32>
228-
%expanded_1 = tensor.expand_shape %0 [[0, 1, 2], [3], [4], [5]] : tensor<10x22x333x4444xi32> into tensor<1x10x1x22x333x4444xi32>
229-
%expanded_2 = tensor.expand_shape %0 [[0], [1, 2], [3], [4]] : tensor<10x22x333x4444xi32> into tensor<10x22x1x333x4444xi32>
230-
%expanded_3 = tensor.expand_shape %0 [[0], [1], [2, 3], [4]] : tensor<10x22x333x4444xi32> into tensor<10x22x333x1x4444xi32>
231-
%expanded_4 = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] : tensor<10x22x333x4444xi32> into tensor<10x22x333x4444x1xi32>
232-
%expanded_5 = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] : tensor<10x22x333x4444xi32> into tensor<10x22x333x4444x1xi32>
233-
%expanded_6 = tensor.expand_shape %0 [[0, 1], [2], [3], [4, 5]] : tensor<10x22x333x4444xi32> into tensor<10x1x22x333x4444x1xi32>
234-
%expanded_7 = tensor.expand_shape %0 [[0, 1], [2, 3], [4], [5, 6]] : tensor<10x22x333x4444xi32> into tensor<10x1x22x1x333x4444x1xi32>
235-
%expanded_8 = tensor.expand_shape %0 [[0, 1], [2, 3], [4, 5], [6, 7]] : tensor<10x22x333x4444xi32> into tensor<10x1x22x1x333x1x4444x1xi32>
236-
%expanded_9 = tensor.expand_shape %0 [[0, 1, 2], [3, 4], [5, 6], [7, 8]] : tensor<10x22x333x4444xi32> into tensor<1x10x1x22x1x333x1x4444x1xi32>
226+
%expanded = tensor.expand_shape %0 [[0, 1], [2], [3], [4]] output_shape [1, 10, 22, 333, 4444] : tensor<10x22x333x4444xi32> into tensor<1x10x22x333x4444xi32>
227+
%expanded_0 = tensor.expand_shape %0 [[0, 1], [2], [3], [4]] output_shape [10, 1, 22, 333, 4444] : tensor<10x22x333x4444xi32> into tensor<10x1x22x333x4444xi32>
228+
%expanded_1 = tensor.expand_shape %0 [[0, 1, 2], [3], [4], [5]] output_shape [1, 10, 1, 22, 333, 4444] : tensor<10x22x333x4444xi32> into tensor<1x10x1x22x333x4444xi32>
229+
%expanded_2 = tensor.expand_shape %0 [[0], [1, 2], [3], [4]] output_shape [10, 22, 1, 333, 4444] : tensor<10x22x333x4444xi32> into tensor<10x22x1x333x4444xi32>
230+
%expanded_3 = tensor.expand_shape %0 [[0], [1], [2, 3], [4]] output_shape [10, 22, 333, 1, 4444] : tensor<10x22x333x4444xi32> into tensor<10x22x333x1x4444xi32>
231+
%expanded_4 = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [10, 22, 333, 4444, 1] : tensor<10x22x333x4444xi32> into tensor<10x22x333x4444x1xi32>
232+
%expanded_5 = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [10, 22, 333, 4444, 1] : tensor<10x22x333x4444xi32> into tensor<10x22x333x4444x1xi32>
233+
%expanded_6 = tensor.expand_shape %0 [[0, 1], [2], [3], [4, 5]] output_shape [10, 1, 22, 333, 4444, 1] : tensor<10x22x333x4444xi32> into tensor<10x1x22x333x4444x1xi32>
234+
%expanded_7 = tensor.expand_shape %0 [[0, 1], [2, 3], [4], [5, 6]] output_shape [10, 1, 22, 1, 333, 4444, 1] : tensor<10x22x333x4444xi32> into tensor<10x1x22x1x333x4444x1xi32>
235+
%expanded_8 = tensor.expand_shape %0 [[0, 1], [2, 3], [4, 5], [6, 7]] output_shape [10, 1, 22, 1, 333, 1, 4444, 1] : tensor<10x22x333x4444xi32> into tensor<10x1x22x1x333x1x4444x1xi32>
236+
%expanded_9 = tensor.expand_shape %0 [[0, 1, 2], [3, 4], [5, 6], [7, 8]] output_shape [1, 10, 1, 22, 1, 333, 1, 4444, 1] : tensor<10x22x333x4444xi32> into tensor<1x10x1x22x1x333x1x4444x1xi32>
237237
}
238238
"""
239239
)

0 commit comments

Comments
 (0)