
Commit 1c23921

[MLIR][Linalg] Remove matmul_transpose variants (#147961)
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replaces them with `matmul affine_maps [...]` wherever appropriate. This is in line with the [plan](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863) and became possible once #104783 merged. See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245

Issues investigated:
* Pad transform tests that could use `matmul` instead were changed to do so.
* An ArmSME test using transpose actually needed it, so it was changed to `matmul` + affine maps. Arm tests validated by @banach-space (thanks!).
1 parent 04f9312 commit 1c23921
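For reference, the removed transpose variants map onto plain `linalg.matmul` with explicit `indexing_maps`, the mechanism introduced by #104783. A minimal sketch of the `matmul_transpose_a` form (function name and tensor shapes are illustrative, not taken from this commit):

```mlir
// Sketch: matmul_transpose_a expressed as linalg.matmul with explicit
// indexing_maps (d0 = M, d1 = N, d2 = K). The A map reads A[k, m], i.e. the
// LHS is consumed transposed; putting the transposition on the B map instead
// gives the matmul_transpose_b form.
func.func @matmul_transpose_a(%A: tensor<5x3xf32>, %B: tensor<5x7xf32>,
                              %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
  %0 = linalg.matmul
         indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,  // A: (K, M)
                          affine_map<(d0, d1, d2) -> (d2, d1)>,  // B: (K, N)
                          affine_map<(d0, d1, d2) -> (d0, d1)>]  // C: (M, N)
         ins(%A, %B : tensor<5x3xf32>, tensor<5x7xf32>)
         outs(%C : tensor<3x7xf32>) -> tensor<3x7xf32>
  return %0 : tensor<3x7xf32>
}
```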

File tree

1 file changed (+0, -93 lines)


mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 0 additions & 93 deletions
```diff
@@ -373,42 +373,6 @@ def quantized_matmul(
     )
 
 
-@linalg_structured_op
-def matmul_transpose_a(
-    A=TensorDef(T1, S.K, S.N),
-    B=TensorDef(T2, S.K, S.M),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs with lhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
-
-
-@linalg_structured_op
-def matmul_transpose_b(
-    A=TensorDef(T1, S.M, S.K),
-    B=TensorDef(T2, S.N, S.K),
-    C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Performs a matrix multiplication of two 2D inputs with rhs operand
-    transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
-
-
 @linalg_structured_op
 def mmt4d(
     lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
```
```diff
@@ -453,44 +417,6 @@ def batch_mmt4d(
     ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0])
 
 
-@linalg_structured_op
-def batch_matmul_transpose_a(
-    A=TensorDef(T1, Batch, S.K, S.M),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
-    """Performs a batched matrix multiplication of two 3D inputs where lhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n]
-    )
-
-
-@linalg_structured_op
-def batch_matmul_transpose_b(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.N, S.K),
-    C=TensorDef(U, Batch, S.M, S.N, output=True),
-):
-    """Performs a batched matrix multiplication of two 3D inputs where rhs operand
-    has its non-batch dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.n, D.k]
-    )
-
-
 @linalg_structured_op
 def quantized_batch_matmul(
     A=TensorDef(T1, Batch, S.M, S.K),
```
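The batch variants follow the same pattern. A rough sketch of the `batch_matmul_transpose_b` layout, assuming `linalg.batch_matmul` accepts the same optional `indexing_maps` attribute as `linalg.matmul` (shapes illustrative):

```mlir
// Sketch: batch_matmul_transpose_b as linalg.batch_matmul with explicit maps
// (d0 = batch, d1 = M, d2 = N, d3 = K); the B map consumes B[b, n, k].
func.func @batch_matmul_transpose_b(%A: tensor<2x3x5xf32>, %B: tensor<2x7x5xf32>,
                                    %C: tensor<2x3x7xf32>) -> tensor<2x3x7xf32> {
  %0 = linalg.batch_matmul
         indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,  // A: (B, M, K)
                          affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,  // B: (B, N, K)
                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]  // C: (B, M, N)
         ins(%A, %B : tensor<2x3x5xf32>, tensor<2x7x5xf32>)
         outs(%C : tensor<2x3x7xf32>) -> tensor<2x3x7xf32>
  return %0 : tensor<2x3x7xf32>
}
```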
```diff
@@ -512,25 +438,6 @@ def quantized_batch_matmul(
     ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
 
-@linalg_structured_op
-def batch_reduce_matmul(
-    A=TensorDef(T1, Batch, S.M, S.K),
-    B=TensorDef(T2, Batch, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True),
-):
-    """Performs a batch-reduce matrix multiplication of two 3D inputs.
-    The partial multiplication results are reduced into a 2D output.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-    """
-    domain(D.b, D.m, D.n, D.k)
-    implements(ContractionOpInterface)
-    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n]
-    )
-
-
 @linalg_structured_op
 def matvec(
     A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True)
```
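The removed `batch_reduce_matmul` computes `C[m, n] += A[b, m, k] * B[b, k, n]`, reducing over both `b` and `k`. As a reference point only, the same contraction can be spelled as a `linalg.generic`; this is a hedged sketch with illustrative shapes, not a replacement prescribed by this commit:

```mlir
// Sketch: the batch-reduce matmul contraction as a linalg.generic.
// d0 = batch (reduction), d1 = M, d2 = N, d3 = K (reduction).
#mapA = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>  // A: (B, M, K)
#mapB = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>  // B: (B, K, N)
#mapC = affine_map<(d0, d1, d2, d3) -> (d1, d2)>      // C: (M, N)

func.func @batch_reduce_matmul(%A: tensor<4x8x16xf32>, %B: tensor<4x16x32xf32>,
                               %C: tensor<8x32xf32>) -> tensor<8x32xf32> {
  %0 = linalg.generic
         {indexing_maps = [#mapA, #mapB, #mapC],
          iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
         ins(%A, %B : tensor<4x8x16xf32>, tensor<4x16x32xf32>)
         outs(%C : tensor<8x32xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    %prod = arith.mulf %a, %b : f32
    %sum = arith.addf %acc, %prod : f32
    linalg.yield %sum : f32
  } -> tensor<8x32xf32>
  return %0 : tensor<8x32xf32>
}
```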
