
Commit 673ca35

Fix default FMA implementation for tensors with integer elements (#7419)
Several failing examples from our repo:

```bash
FAILED language/test_core.py::test_dot[1-128-256-32-8-True-True-none-tf32-int8-int8-1-None0] - RuntimeError: PassManager::run failed
FAILED language/test_core.py::test_dot[1-128-256-32-8-True-True-none-tf32-int8-int8-1-None1] - RuntimeError: PassManager::run failed
FAILED language/test_core.py::test_dot[1-128-256-32-8-True-False-none-tf32-int8-int8-1-None0] - RuntimeError: PassManager::run failed
```

Most likely a different implementation is used in such cases. I could add a test for these cases, but I would need to somehow disable the more advanced implementations (I'm not sure what a good way to do that is).

---------

Signed-off-by: Anatoly Myachev <[email protected]>
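For context on what these cases exercise: `tl.dot` on int8 operands accumulates into int32, and on targets that fall back to the generic FMA lowering the old code emitted `llvm.intr.fmuladd` on `i32` values. Below is a minimal sketch of such a kernel; the shapes, names, and the assumption that this configuration actually reaches the FMA path (rather than an MMA path) are illustrative only, not taken from `language/test_core.py`.

```python
# Illustrative sketch, not the failing test: int8 x int8 -> int32 tl.dot.
import torch
import triton
import triton.language as tl


@triton.jit
def int8_dot_kernel(a_ptr, b_ptr, c_ptr,
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Load one int8 tile of A (M x K) and one of B (K x N); row-major, single block.
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # Integer dot: accumulates in int32; this is the op whose generic FMA lowering
    # previously produced llvm.intr.fmuladd on i32 and failed the pass pipeline.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


# Hypothetical usage: the whole problem fits in a single block.
M, N, K = 32, 32, 32
a = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")
b = torch.randint(-8, 8, (K, N), dtype=torch.int8, device="cuda")
c = torch.empty((M, N), dtype=torch.int32, device="cuda")
int8_dot_kernel[(1,)](a, b, c, M=M, N=N, K=K)
assert torch.equal(c.cpu(), a.cpu().int() @ b.cpu().int())
```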
1 parent 8e79a35 commit 673ca35

File tree

2 files changed: +47 −2 lines changed


lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp

Lines changed: 21 additions & 2 deletions
```diff
@@ -1,5 +1,6 @@
 #include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -19,8 +20,26 @@ class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
     auto K = a.size();
     assert(b.size() == K);
     Value accum = c;
-    for (auto [aElem, bElem] : llvm::zip(a, b))
-      accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
+    Type tgtTy = accum.getType();
+    for (auto it = llvm::zip(a, b).begin(); it != llvm::zip(a, b).end(); ++it) {
+      const auto &aElem = std::get<0>(*it);
+      const auto &bElem = std::get<1>(*it);
+
+      assert(aElem.getType() == tgtTy);
+      assert(bElem.getType() == tgtTy);
+
+      // to avoid: 'llvm.intr.fmuladd' op operand #0 must be floating point LLVM
+      // type or LLVM dialect-compatible vector of floating point LLVM type, but
+      // got 'i32'
+      llvm::TypeSwitch<Type>(tgtTy)
+          .Case<FloatType>([&](auto) {
+            accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
+          })
+          .Case<IntegerType>([&](auto) {
+            accum = builder.create<LLVM::AddOp>(
+                loc, builder.create<LLVM::MulOp>(loc, aElem, bElem), accum);
+          });
+    }
     return accum;
   }
 };
```
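In scalar terms, the lowering above keeps the single fused multiply-add only for floating-point element types and switches to a separate multiply and add for integer element types. A rough per-element Python analogue (not code from the repo; `math.fma` needs Python 3.13+):

```python
import math


def dot_accumulate(a_vec, b_vec, c, elem_is_float):
    """Rough analogue of GenericFMAVectorMultiplier's inner loop after this change."""
    accum = c
    for a_elem, b_elem in zip(a_vec, b_vec):
        if elem_is_float:
            # floating-point elements: one fused multiply-add (llvm.intr.fmuladd)
            accum = math.fma(a_elem, b_elem, accum)
        else:
            # integer elements: plain multiply then add (llvm.mul + llvm.add)
            accum = a_elem * b_elem + accum
    return accum
```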

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 26 additions & 0 deletions
```diff
@@ -1347,6 +1347,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#blocked}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#blocked}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:70", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_fmadot_integer
+  tt.func @matmul_fmadot_integer(%ptr:!tt.ptr<i32> {tt.divisibility = 16 : i32},
+                                 %a:!ttg.memdesc<32x16xi32, #shared, #smem>, %b:!ttg.memdesc<16x32xi32, #shared, #smem>) {
+    %cst = arith.constant dense<0> : tensor<32x32xi32, #blocked>
+    // CHECK-NOT: llvm.intr.fmuladd
+    // CHECK: llvm.mul
+    // CHECK: llvm.add
+    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xi32, #shared, #smem> -> tensor<32x16xi32, #dot_operand_a>
+    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xi32, #shared, #smem> -> tensor<16x32xi32, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xi32, #dot_operand_a> * tensor<16x32xi32, #dot_operand_b> -> tensor<32x32xi32, #blocked>
+    %30 = tt.splat %ptr : !tt.ptr<i32> -> tensor<32x1x!tt.ptr<i32>, #blocked>
+    %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr<i32>, #blocked> -> tensor<32x32x!tt.ptr<i32>, #blocked>
+    tt.store %36, %28 : tensor<32x32x!tt.ptr<i32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
```
