Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 86347f6

Browse files
authored
[mlir][linalg] fix linalg.batch_reduce_matmul auto cast (#102585)
Fix the auto-cast of `linalg.batch_reduce_matmul` from `cast_to_T(A * cast_to_T(B)) + C` to `cast_to_T(A) * cast_to_T(B) + C`
1 parent 9239e15 commit 86347f6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,8 @@ def batch_reduce_matmul(
592592
"""
593593
domain(D.b, D.m, D.n, D.k)
594594
implements(ContractionOpInterface)
595-
C[D.m, D.n] += TypeFn.cast_signed(
596-
U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
595+
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
596+
U, B[D.b, D.k, D.n]
597597
)
598598

599599

0 commit comments

Comments
 (0)