Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 202b32b

Browse files
authored
[MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (#130944)
…_reduce_matmul. This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of other two variant of matmul ops. The broadcast and transpose semantic are as follows: Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified. Example Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast and Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 llvm/llvm-project#115319 llvm/llvm-project#122275
1 parent 346f919 commit 202b32b

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,19 @@ def batch_matmul(
203203
)
204204

205205

206+
def batch_reduce_matmul(
207+
*ins: Union[Operation, OpView, Value],
208+
outs: Sequence[Union[Operation, OpView, Value]],
209+
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
210+
cast: Optional[Union[TypeFn, Attribute]] = None,
211+
):
212+
return _get_op_result_or_op_results(
213+
_create_matmul_like_op(
214+
BatchReduceMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast
215+
)
216+
)
217+
218+
206219
def contract(
207220
*ins: Union[Operation, OpView, Value],
208221
outs: Sequence[Union[Operation, OpView, Value]],

0 commit comments

Comments
 (0)