Commit 061bbc5
[torch] Update torch.bmm to use accumulator type (llvm#3924)
Batch matmul was using the result element type as its accumulator type. Updated to use the preferred accumulator type for the input element type instead.
1 parent e68560d commit 061bbc5
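In other words, the lowering now allocates the linalg.batch_matmul init tensor in the accumulator type returned by getDefaultAccType (typically a wider type, e.g. f32 when the element type is f16) and converts back to the result element type afterwards. Below is a minimal PyTorch sketch of that numeric pattern, for illustration only: the tensor names and explicit casts are assumptions made for this example, not part of the torch-mlir change itself.

import torch

# Illustrative sketch (not torch-mlir code): the numeric pattern this commit
# applies when lowering torch.bmm. Tensor names here are made up.
lhs = torch.rand(3, 4, 5, dtype=torch.float16)
rhs = torch.rand(3, 5, 4, dtype=torch.float16)

# Accumulate in the wider f32 type (analogous to the init tensor now being
# created with accumulatorDType), then convert back to the f16 result type
# (analogous to the convertTensorToElementType call in the diff below).
acc = torch.bmm(lhs.to(torch.float32), rhs.to(torch.float32))
result = acc.to(torch.float16)
print(result.dtype)  # torch.float16, but the sums were formed in float32

Accumulating the reduction directly in f16 loses precision and can overflow as the inner dimension grows; accumulating in f32 and truncating once at the end typically matches what PyTorch backends do for half-precision matmuls.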

File tree

2 files changed: +31 / -2 lines changed
  • lib/Conversion/TorchToLinalg/Linear.cpp
  • projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 8 additions & 2 deletions
@@ -727,15 +727,21 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
     // Check the matrixs shapes are valid for mulplication.
     checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);
 
+    Type accumulatorDType = getDefaultAccType(rewriter, resultElementType);
     Value initTensor0 = createZeroInitTensor(
-        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2},
-        resultElementType);
+        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, accumulatorDType);
 
     Value bmm =
         rewriter
             .create<linalg::BatchMatmulOp>(loc, initTensor0.getType(),
                                            ValueRange{lhs, rhs}, initTensor0)
             .getResult(0);
+
+    if (accumulatorDType != resultElementType) {
+      bmm = torch_to_linalg::convertTensorToElementType(rewriter, loc, bmm,
+                                                        resultElementType);
+    }
+
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, bmm);
     return success();
   }

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 23 additions & 0 deletions
@@ -87,6 +87,29 @@ def BmmFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
 
 
+class BmmFloat16Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float16, True),
+            ([-1, -1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.bmm(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: BmmFloat16Module())
+def BmmFloat16Module_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(3, 4, 5).to(torch.float16), tu.rand(3, 5, 4).to(torch.float16)
+    )
+
+
 class BmmIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()