Skip to content

Commit 2e6bf66

Browse files
committed
Make indexing_maps optional on matmul, as it should be
1 parent 18ff3a6 commit 2e6bf66

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def matmul(
153153
inputs: Sequence[Union[Operation, OpView, Value]],
154154
*,
155155
outs: Sequence[Union[Operation, OpView, Value]],
156-
indexing_maps: Sequence[AffineMapAttr],
156+
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
157157
cast: Optional[Union[TypeFn, Attribute]] = None,
158158
):
159159
inputs = [_get_op_result_or_value(input) for input in inputs]

mlir/test/python/dialects/linalg/ops.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,10 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
307307
result_tensors=(C.type,),
308308
inputs=(A, B),
309309
outputs=(C,),
310-
indexing_maps=[a_map, b_map, c_map],
311310
)
312311
linalg.fill_builtin_region(res.operation)
313312
# CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
314-
res = linalg.matmul(
315-
(A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
316-
)
313+
res = linalg.matmul((A, B), outs=(C,))
317314

318315
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
319316
res = linalg.MatmulOp(
@@ -337,12 +334,11 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
337334
result_tensors=[],
338335
inputs=(Amem, Bmem),
339336
outputs=(Cmem,),
340-
indexing_maps=[a_map, b_map, c_map],
341337
)
342338
linalg.fill_builtin_region(res.operation)
343339
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
344340
linalg.matmul(
345-
(Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
341+
(Amem, Bmem), outs=(Cmem,)
346342
)
347343

348344
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)

0 commit comments

Comments
 (0)