Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 1a51cc2

Browse files
authored
[MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (#126377)
Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op. Required following misc. fixes: 1) make linalg.matmul's parsing and printing consistent w.r.t. whether indexing_maps occurs before or after operands, i.e. per the tests cases it comes _before_. 2) tablegen for linalg.contract did not state it accepted an optional cast attr. 3) In ODS's C++-generating code, expand partial support for `$_builder` access in `Attr::defaultValue` to full support. This enables access to the current `MlirContext` when constructing the default value (as is required when the default value consists of affine maps).
1 parent 06ef50a commit 1a51cc2

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,49 @@ def __init__(
147147

148148

149149
generic = region_op(GenericOp_, terminator=YieldOp)
150+
151+
152+
def matmul(
153+
*ins: Union[Operation, OpView, Value],
154+
outs: Sequence[Union[Operation, OpView, Value]],
155+
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
156+
cast: Optional[Union[TypeFn, Attribute]] = None,
157+
):
158+
ins = [_get_op_result_or_value(input) for input in ins]
159+
if len(outs) > 1:
160+
raise ValueError(f"{outs=} must have length 1.")
161+
init = _get_op_result_or_value(outs[0])
162+
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
163+
164+
op = MatmulOp(
165+
result_tensors=result_types,
166+
inputs=ins,
167+
outputs=[init],
168+
indexing_maps=indexing_maps,
169+
cast=cast,
170+
)
171+
fill_builtin_region(op.operation)
172+
return op
173+
174+
175+
def contract(
176+
*ins: Union[Operation, OpView, Value],
177+
outs: Sequence[Union[Operation, OpView, Value]],
178+
indexing_maps: Sequence[AffineMapAttr],
179+
cast: Optional[Union[TypeFn, Attribute]] = None,
180+
):
181+
ins = [_get_op_result_or_value(input) for input in ins]
182+
if len(outs) > 1:
183+
raise ValueError(f"{outs=} must have length 1.")
184+
init = _get_op_result_or_value(outs[0])
185+
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
186+
187+
op = ContractOp(
188+
result_tensors=result_types,
189+
inputs=ins,
190+
outputs=[init],
191+
indexing_maps=indexing_maps,
192+
cast=cast,
193+
)
194+
fill_builtin_region(op.operation)
195+
return op

0 commit comments

Comments
 (0)