Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit c2dc4f5

Browse files
authored
[MLIR][Linalg] Introduce Python API for linalg.batch_matmul Ops. (#127614)
As linalg.batch_matmul has been moved into tablegen from OpDSL, its derived python wrapper no longer exist.This patch adds the required python wrapper. Also refactors the BatchmatmulOp printer to make it consistent with its parser.
1 parent bed39aa commit c2dc4f5

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def __init__(
149149
generic = region_op(GenericOp_, terminator=YieldOp)
150150

151151

152-
def matmul(
152+
def create_op(
153+
op_type,
153154
*ins: Union[Operation, OpView, Value],
154155
outs: Sequence[Union[Operation, OpView, Value]],
155156
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
@@ -161,7 +162,7 @@ def matmul(
161162
init = _get_op_result_or_value(outs[0])
162163
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
163164

164-
op = MatmulOp(
165+
op = op_type(
165166
result_tensors=result_types,
166167
inputs=ins,
167168
outputs=[init],
@@ -172,24 +173,32 @@ def matmul(
172173
return op
173174

174175

176+
def matmul(
177+
*ins: Union[Operation, OpView, Value],
178+
outs: Sequence[Union[Operation, OpView, Value]],
179+
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
180+
cast: Optional[Union[TypeFn, Attribute]] = None,
181+
):
182+
return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast)
183+
184+
185+
def batch_matmul(
186+
*ins: Union[Operation, OpView, Value],
187+
outs: Sequence[Union[Operation, OpView, Value]],
188+
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
189+
cast: Optional[Union[TypeFn, Attribute]] = None,
190+
):
191+
return create_op(
192+
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
193+
)
194+
195+
175196
def contract(
176197
*ins: Union[Operation, OpView, Value],
177198
outs: Sequence[Union[Operation, OpView, Value]],
178199
indexing_maps: Sequence[AffineMapAttr],
179200
cast: Optional[Union[TypeFn, Attribute]] = None,
180201
):
181-
ins = [_get_op_result_or_value(input) for input in ins]
182-
if len(outs) > 1:
183-
raise ValueError(f"{outs=} must have length 1.")
184-
init = _get_op_result_or_value(outs[0])
185-
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
186-
187-
op = ContractOp(
188-
result_tensors=result_types,
189-
inputs=ins,
190-
outputs=[init],
191-
indexing_maps=indexing_maps,
192-
cast=cast,
202+
return create_op(
203+
ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
193204
)
194-
fill_builtin_region(op.operation)
195-
return op

0 commit comments

Comments
 (0)