Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 6f44960

Browse files
authored
[MLIR][Linalg] Introduce broadcast/transpose semantic to batch_matmul (#122275)
Goals: 1. To add syntax and semantic to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul. 2. Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra. Scope of this patch: To expose broadcast and transpose semantics on the 'batch_matmul'. The broadcast and transpose semantic are as follows: By default, 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified. Example Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose affine_map< (d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<2x5x3xf32>,memref<2x5x7xf32>) outs (%arg2: memref<2x3x7xf32>) ``` Example Broadcast: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d3)>, //broadcast affine_map< (d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<5xf32>,memref<2x5x7xf32>) outs (%arg2: memref<2x3x7xf32>) ``` Example Broadcast and transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d1, d3)>, //broadcast affine_map< (d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<3x5xf32>, memref<2x7x5xf32>) outs (%arg2: memref<2x3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 llvm/llvm-project#115319
1 parent 67b6d46 commit 6f44960

File tree

1 file changed

+0
-18
lines changed

1 file changed

+0
-18
lines changed

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -484,24 +484,6 @@ def batch_mmt4d(
484484
) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
485485

486486

487-
@linalg_structured_op
488-
def batch_matmul(
489-
A=TensorDef(T1, Batch, S.M, S.K),
490-
B=TensorDef(T2, Batch, S.K, S.N),
491-
C=TensorDef(U, Batch, S.M, S.N, output=True),
492-
):
493-
"""Performs a batched matrix multiplication of two 3D inputs.
494-
495-
Numeric casting is performed on the operands to the inner multiply, promoting
496-
them to the same data type as the accumulator/output.
497-
"""
498-
domain(D.b, D.m, D.n, D.k)
499-
implements(ContractionOpInterface)
500-
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
501-
U, B[D.b, D.k, D.n]
502-
)
503-
504-
505487
@linalg_structured_op
506488
def batch_matmul_transpose_a(
507489
A=TensorDef(T1, Batch, S.K, S.M),

0 commit comments

Comments
 (0)