Skip to content

Commit f83f7ae

Browse files
committed
Switch argument format to that of OpDSL-derived linalg ops
1 parent c81c97f commit f83f7ae

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,21 +150,20 @@ def __init__(
150150

151151

152152
def matmul(
153-
inputs: Sequence[Union[Operation, OpView, Value]],
154-
*,
153+
*ins: Union[Operation, OpView, Value],
155154
outs: Sequence[Union[Operation, OpView, Value]],
156155
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
157156
cast: Optional[Union[TypeFn, Attribute]] = None,
158157
):
159-
inputs = [_get_op_result_or_value(input) for input in inputs]
158+
ins = [_get_op_result_or_value(input) for input in ins]
160159
if len(outs) > 1:
161160
raise ValueError(f"{outs=} must have length 1.")
162161
init = _get_op_result_or_value(outs[0])
163162
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
164163

165164
op = MatmulOp(
166165
result_tensors=result_types,
167-
inputs=inputs,
166+
inputs=ins,
168167
outputs=[init],
169168
indexing_maps=indexing_maps,
170169
cast=cast,
@@ -174,21 +173,20 @@ def matmul(
174173

175174

176175
def contract(
177-
inputs: Sequence[Union[Operation, OpView, Value]],
178-
*,
176+
*ins: Union[Operation, OpView, Value],
179177
outs: Sequence[Union[Operation, OpView, Value]],
180178
indexing_maps: Sequence[AffineMapAttr],
181179
cast: Optional[Union[TypeFn, Attribute]] = None,
182180
):
183-
inputs = [_get_op_result_or_value(input) for input in inputs]
181+
ins = [_get_op_result_or_value(input) for input in ins]
184182
if len(outs) > 1:
185183
raise ValueError(f"{outs=} must have length 1.")
186184
init = _get_op_result_or_value(outs[0])
187185
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
188186

189187
op = ContractOp(
190188
result_tensors=result_types,
191-
inputs=inputs,
189+
inputs=ins,
192190
outputs=[init],
193191
indexing_maps=indexing_maps,
194192
cast=cast,

mlir/test/python/dialects/linalg/ops.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
310310
)
311311
linalg.fill_builtin_region(res.operation)
312312
# CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
313-
res = linalg.matmul((A, B), outs=(C,))
313+
res = linalg.matmul(A, B, outs=(C,))
314314

315315
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
316316
res = linalg.MatmulOp(
@@ -322,7 +322,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
322322
linalg.fill_builtin_region(res.operation)
323323
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
324324
res = linalg.matmul(
325-
(A, Btransposed),
325+
A, Btransposed,
326326
outs=(C,),
327327
indexing_maps=[a_map, b_transposed_map, c_map],
328328
)
@@ -337,7 +337,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
337337
)
338338
linalg.fill_builtin_region(res.operation)
339339
# CHECK: linalg.matmul ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
340-
linalg.matmul((Amem, Bmem), outs=(Cmem,))
340+
linalg.matmul(Amem, Bmem, outs=(Cmem,))
341341

342342
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
343343
res = linalg.MatmulOp(
@@ -349,7 +349,7 @@ def matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
349349
linalg.fill_builtin_region(res.operation)
350350
# CHECK: linalg.matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
351351
linalg.matmul(
352-
(Amem, Btransposedmem),
352+
Amem, Btransposedmem,
353353
outs=(Cmem,),
354354
indexing_maps=[a_map, b_transposed_map, c_map],
355355
)
@@ -414,7 +414,7 @@ def matmul_as_contract_op(
414414
linalg.fill_builtin_region(op4.operation)
415415
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[B]] : tensor<4x8xf32>, tensor<8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
416416
op5 = linalg.contract(
417-
(A, B), outs=(C,), indexing_maps=[a_map, b_map, c_map]
417+
A, B, outs=(C,), indexing_maps=[a_map, b_map, c_map]
418418
)
419419

420420
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
@@ -427,7 +427,7 @@ def matmul_as_contract_op(
427427
linalg.fill_builtin_region(op4.operation)
428428
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<4x8xf32>, tensor<12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
429429
op5 = linalg.contract(
430-
(A, Btransposed),
430+
A, Btransposed,
431431
outs=(C,),
432432
indexing_maps=[a_map, b_transposed_map, c_map],
433433
)
@@ -443,7 +443,7 @@ def matmul_as_contract_op(
443443
linalg.fill_builtin_region(op4.operation)
444444
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$B_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[Bmem]] : memref<4x8xf32>, memref<8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
445445
linalg.contract(
446-
(Amem, Bmem), outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
446+
Amem, Bmem, outs=(Cmem,), indexing_maps=[a_map, b_map, c_map]
447447
)
448448

449449
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
@@ -456,7 +456,7 @@ def matmul_as_contract_op(
456456
linalg.fill_builtin_region(op4.operation)
457457
# CHECK: linalg.contract indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<4x8xf32>, memref<12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
458458
linalg.contract(
459-
(Amem, Btransposedmem),
459+
Amem, Btransposedmem,
460460
outs=(Cmem,),
461461
indexing_maps=[a_map, b_transposed_map, c_map],
462462
)

0 commit comments

Comments
 (0)