
Commit 33e7dc2

[BACKEND] Implement BF16x3 trick (#7592)
**Update:** I have found that for better performance we should use 3 to 6 BF16 dot products, but not more. My findings are at: https://gist.github.com/plotfi/72554bd410ea55d8ae67b501c69b2766. The short version is that the Triton Bench tutorial matmul with F32 inputs benefits by 60-70% using 3 BF16 dots, or 10-15% using 6 BF16 dots. I think this is sufficient to move forward with as a replacement for TF32 on MI350, and it is in line with what hipBLASLt does: https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330. There is a similar implementation in XLA as well: https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152

--------

Implements emulation of a 32-bit floating point dot operation using 3 BF16s. This is based on https://arxiv.org/abs/1904.06376 and works because the mantissas of 3 BF16s add up to the mantissa of an fp32.

Storing 1 fp32 in 3 bf16s:

```python
import torch

def BF16(v):
    return v.to(torch.bfloat16)

def FP32(v):
    return v.to(torch.float32)

def BF16x3(v):
    b0 = BF16(v)
    b1 = BF16(v - FP32(b0))
    b2 = BF16(v - FP32(b0) - FP32(b1))
    return (b0, b1, b2)

original = torch.rand(1, 1, dtype=torch.float32)
bf16x3 = BF16x3(original)
```

Emulating multiplication of two fp32s:

```python
def mul_bf16x3(a, b, c):
    a0, a1, a2 = BF16x3(a)
    b0, b1, b2 = BF16x3(b)
    c = c + (a0 * b0)  # low low
    c = c + (a1 * b0)  # mid low
    c = c + (a0 * b1)  # low mid
    c = c + (a1 * b1)  # mid mid
    c = c + (a0 * b2)  # low hi
    c = c + (a2 * b0)  # hi low
    c = c + (a1 * b2)  # mid hi
    c = c + (a2 * b1)  # hi mid
    c = c + (a2 * b2)  # hi hi
    return c

a = torch.rand(1, 1, dtype=torch.float32)
b = torch.rand(1, 1, dtype=torch.float32)
c = torch.zeros(1, 1, dtype=torch.float32)  # accumulator
result = mul_bf16x3(a, b, c)
```

The BF16x3 emulation is used when invoking tl.dot with input precision 'BF16x3'. The pass is implemented in a GPU-agnostic manner, but it is needed primarily because the MI350 lacks TF32 support. That part is a work in progress and will be based on this patch.
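For context on the user-facing side, here is a minimal usage sketch (mine, not part of this commit) that exercises the new input precision from the Python API; the single-block kernel, shapes, and launch parameters are assumptions for illustration only:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def small_dot_kernel(a_ptr, b_ptr, c_ptr,
                     M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # Illustrative single-block kernel: load full f32 tiles and let the pass
    # decompose the dot into BF16 dots via input_precision="bf16x3".
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    acc = tl.zeros((M, N), dtype=tl.float32)
    acc = tl.dot(a, b, acc, input_precision="bf16x3")  # or "bf16x6"
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)

a = torch.randn(32, 32, device="cuda", dtype=torch.float32)
b = torch.randn(32, 32, device="cuda", dtype=torch.float32)
c = torch.empty_like(a)
small_dot_kernel[(1,)](a, b, c, M=32, N=32, K=32)
print((c - a @ b).abs().max())  # expect a small error vs. full fp32
```

With `input_precision="bf16x3"` the new pass decomposes the f32 `tl.dot`; the test changes in this commit check that on NVIDIA this lowers to bf16 MMA instructions rather than tf32 ones.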
1 parent 3251bb8 commit 33e7dc2

File tree

12 files changed: +243 -51 lines changed


include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 3 additions & 1 deletion
@@ -129,7 +129,9 @@ def TT_InputPrecisionAttr : I32EnumAttr<
     [
       I32EnumAttrCase<"TF32", 0, "tf32">,
       I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
-      I32EnumAttrCase<"IEEE", 2, "ieee">
+      I32EnumAttrCase<"IEEE", 2, "ieee">,
+      I32EnumAttrCase<"BF16x3", 3, "bf16x3">,
+      I32EnumAttrCase<"BF16x6", 4, "bf16x6">
     ]>{
   let cppNamespace = "::mlir::triton";
 }

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 3 additions & 1 deletion
@@ -664,9 +664,11 @@ def TT_DotOp : TT_Op<"dot", [Pure,
 
   let description = [{
     $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC
-    when the inputs are f32. It can be one of: tf32, tf32x3, ieee.
+    when the inputs are f32. It can be one of: tf32, tf32x3, ieee, bf16x3, bf16x6.
     tf32: use TC with tf32 ops.
     tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp
+    bf16x3: implement the 3xBF16 trick. For more info see the pass in F32DotTC.cpp
+    bf16x6: implement the 6xBF16 trick. For more info see the pass in F32DotTC.cpp
     ieee: don't use TC, implement dot in software.
     If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored.
   }];

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 12 additions & 5 deletions
@@ -177,15 +177,22 @@ def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir:
 }
 
 def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
-  let summary = "3xTF32 trick";
+  let summary = "Emulate dot-product tensor core precision using TF32s or BF16s";
 
   let description = [{
-    Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s
-    to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385
+    Generic pass to emulate/decompose f32 `DotOp` instructions.
+    * Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s
+      to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385.
+    * Decompose fp32 `DotOp` instructions into BF16 operations.
+      See https://arxiv.org/abs/1904.06376
   }];
 
-  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
-                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+  let options = [
+    Option<"emuTF32", "emu-tf32",
+           "bool", /*default*/"false",
+           "whether to handle InputPrecision TF32xN for Nvidia GPUs">
+  ];
 }
 
 def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {

lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp

Lines changed: 136 additions & 35 deletions
@@ -2,15 +2,134 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 
-namespace mlir {
-namespace triton {
-namespace gpu {
+namespace mlir::triton::gpu {
 
 #define GEN_PASS_DEF_TRITONGPUF32DOTTC
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
 namespace {
 
+template <typename T>
+auto convertValue(Value value, const FloatType &scalarToType,
+                  PatternRewriter &rewriter) -> mlir::Value {
+  auto fromType = cast<RankedTensorType>(value.getType());
+  auto toType = fromType.cloneWith(std::nullopt, scalarToType);
+  return rewriter.create<T>(value.getLoc(), toType, value).getResult();
+}
+
+auto splitF32(Value input, unsigned N, PatternRewriter &rewriter)
+    -> llvm::SmallVector<Value, 3> {
+  llvm::SmallVector<Value, 3> splitInputs;
+  for (unsigned i = 0; i < N; ++i) {
+    Value inputAsBF16 =
+        convertValue<arith::TruncFOp>(input, rewriter.getBF16Type(), rewriter);
+    if (i != N - 1) {
+      Value inputAsF32 = convertValue<arith::ExtFOp>(
+          inputAsBF16, rewriter.getF32Type(), rewriter);
+      input = rewriter.create<arith::SubFOp>(input.getLoc(), input, inputAsF32);
+    }
+    splitInputs.push_back(inputAsBF16);
+  }
+  return splitInputs;
+}
+
+bool isF32(Value operand) {
+  return cast<RankedTensorType>(operand.getType()).getElementType().isF32();
+};
+
+Value zeroLike(Value c, PatternRewriter &rewriter) {
+  return rewriter.create<SplatOp>(c.getLoc(), c.getType(),
+                                  rewriter.create<arith::ConstantOp>(
+                                      c.getLoc(), rewriter.getF32FloatAttr(0)));
+};
+
+Value dot(Value lhs, Value rhs, Value acc, PatternRewriter &rewriter,
+          InputPrecision precision = InputPrecision::IEEE,
+          uint32_t maxNumImpreciseAcc = 0) {
+  return rewriter.create<DotOp>(lhs.getLoc(), lhs, rhs, acc, precision,
+                                maxNumImpreciseAcc);
+};
+
+Value replaceNansWithZeros(Value value, PatternRewriter &rewriter) {
+  auto nans = rewriter.create<arith::CmpFOp>(
+      value.getLoc(), arith::CmpFPredicate::UNO, value, value);
+  auto zero = zeroLike(value, rewriter);
+  return rewriter.create<arith::SelectOp>(value.getLoc(), nans, zero, value);
+};
+
+unsigned getBF16Count(triton::InputPrecision precision) {
+  switch (precision) {
+  default:
+    return 0;
+  case InputPrecision::BF16x3:
+    // BF16x3 only needs the first 2 values derived from splitting an F32
+    return 2;
+  case InputPrecision::BF16x6:
+    return 3;
+  }
+}
+
+// Implements 3xBF16 https://arxiv.org/abs/1904.06376
+// See also
+// https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152
+// As well as
+// https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330
+struct BF16xN : public OpRewritePattern<DotOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DotOp dotOp,
+                                PatternRewriter &rewriter) const override {
+    // BF16 indices and count
+    const unsigned hi = 0;
+    const unsigned mid = 1;
+    const unsigned lo = 2;
+    const unsigned N = getBF16Count(dotOp.getInputPrecision());
+
+    if (!isF32(dotOp.getA()) || !isF32(dotOp.getB()) || !N)
+      return failure();
+
+    // Starting Values: a(0), a(1), a(2), b(0), b(1), b(2) and zero accumulator
+    const auto lhs_parts = splitF32(dotOp.getA(), N, rewriter);
+    const auto rhs_parts = splitF32(dotOp.getB(), N, rewriter);
+    auto result = zeroLike(dotOp.getC(), rewriter);
+
+    switch (dotOp.getInputPrecision()) {
+    default:
+      assert(false && "BF16DotTCPass expects BF16x6 or BF16x3");
+      return failure();
+
+      // clang-format off
+      // NOTE: 9 dots possible; handled like so if not for lack of speedup:
+      // case InputPrecision::BF16x9:
+      //   result = dot(lhs_parts[lo], rhs_parts[lo], result, rewriter);
+      //   result = dot(lhs_parts[mid], rhs_parts[lo], result, rewriter);
+      //   result = dot(lhs_parts[lo], rhs_parts[mid], result, rewriter);
+      // clang-format on
+
+    case InputPrecision::BF16x6:
+      result = dot(lhs_parts[mid], rhs_parts[mid], result, rewriter);
+
+      result = dot(lhs_parts[lo], rhs_parts[hi], result, rewriter);
+      result = dot(lhs_parts[hi], rhs_parts[lo], result, rewriter);
+
+    case InputPrecision::BF16x3:
+      result = dot(lhs_parts[mid], rhs_parts[hi], result, rewriter);
+      result = dot(lhs_parts[hi], rhs_parts[mid], result, rewriter);
+      result = replaceNansWithZeros(result, rewriter);
+
+      // NOTE: For BF16x1 bail without replaceNansWithZeros
+      // case InputPrecision::BF16x1: break;
+    }
+
+    result = dot(lhs_parts[hi], rhs_parts[hi], result, rewriter);
+    result =
+        rewriter.create<arith::AddFOp>(dotOp.getLoc(), result, dotOp.getC());
+
+    rewriter.replaceOp(dotOp, result);
+    return success();
+  }
+};
+
 // nb. We call the trick TF32x3 as C++ disallows variables starting with numbers
 // Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
 // For a, b f32
@@ -28,11 +147,6 @@ class TF32x3 : public OpRewritePattern<DotOp> {
 
   LogicalResult matchAndRewrite(DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
-
-    auto isF32 = [](Value operand) {
-      return cast<RankedTensorType>(operand.getType()).getElementType().isF32();
-    };
-
     if (!(dotOp.getInputPrecision() == InputPrecision::TF32x3 &&
           isF32(dotOp.getA()) && isF32(dotOp.getB()))) {
       return failure();
@@ -47,41 +161,25 @@ class TF32x3 : public OpRewritePattern<DotOp> {
                                           ArrayRef<Value>{value})
           .getResult()[0];
     };
-    auto zeroLike = [&](Value c) -> Value {
-      return rewriter.create<SplatOp>(
-          dotOp->getLoc(), c.getType(),
-          rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
-                                             rewriter.getF32FloatAttr(0)));
-    };
     auto add = [&](Value a, Value b) -> Value {
      return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b);
     };
    auto sub = [&](Value a, Value b) -> Value {
      return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
    };
-    auto dot = [&](Value a, Value b, Value c) -> Value {
-      return rewriter.create<DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
-                                    InputPrecision::TF32,
-                                    dotOp.getMaxNumImpreciseAcc());
-    };
-    auto replaceNansWithZeros = [&](Value value) -> Value {
-      auto nans = rewriter.create<arith::CmpFOp>(
-          dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
-      auto zero = zeroLike(value);
-      return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
-                                              value);
-    };
 
     auto aBig = f32ToTF32(dotOp.getA());
     auto aSmall = sub(dotOp.getA(), aBig);
 
     auto bBig = f32ToTF32(dotOp.getB());
     auto bSmall = sub(dotOp.getB(), bBig);
 
-    auto zero = zeroLike(dotOp.getC());
+    auto zero = zeroLike(dotOp.getC(), rewriter);
 
-    auto dot1 = dot(aSmall, bBig, zero);
-    auto dot2 = dot(aBig, bSmall, dot1);
+    auto dot1 = dot(aSmall, bBig, zero, rewriter, InputPrecision::TF32,
+                    dotOp.getMaxNumImpreciseAcc());
+    auto dot2 = dot(aBig, bSmall, dot1, rewriter, InputPrecision::TF32,
+                    dotOp.getMaxNumImpreciseAcc());
 
     // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
     // If rhs is +infinity, we will have:
@@ -90,8 +188,9 @@ class TF32x3 : public OpRewritePattern<DotOp> {
     // We would get the wrong result if we sum these partial products. Instead,
     // we must override any accumulated result if the last partial product is
     // non-finite.
-    auto dot2withZeroedNans = replaceNansWithZeros(dot2);
-    auto dot3 = dot(aBig, bBig, dot2withZeroedNans);
+    auto dot2withZeroedNans = replaceNansWithZeros(dot2, rewriter);
+    auto dot3 = dot(aBig, bBig, dot2withZeroedNans, rewriter,
+                    InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc());
 
     auto sum = add(dot3, dotOp.getC());
 
@@ -103,18 +202,20 @@ class TF32x3 : public OpRewritePattern<DotOp> {
 } // anonymous namespace
 
 struct F32DotTCPass : public impl::TritonGPUF32DotTCBase<F32DotTCPass> {
+  using impl::TritonGPUF32DotTCBase<F32DotTCPass>::TritonGPUF32DotTCBase;
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
 
     RewritePatternSet decomposePatterns(context);
-    decomposePatterns.add<TF32x3>(context);
+    if (this->emuTF32) {
+      decomposePatterns.add<TF32x3>(context);
+    }
+    decomposePatterns.add<BF16xN>(context);
    if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) {
      signalPassFailure();
    }
  }
 };
 
-} // namespace gpu
-} // namespace triton
-} // namespace mlir
+} // namespace mlir::triton::gpu
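To make the term selection above easier to follow, here is a host-side sketch (an illustration written for this note, not code from the commit) of the same decomposition: BF16x3 splits each f32 operand into N=2 bf16 terms and issues 3 dots, BF16x6 uses N=3 terms and 6 dots, and in both cases NaNs from the earlier partial products are zeroed before the final hi*hi dot is accumulated.

```python
import torch

def split_f32(x, n):
    # Mirrors splitF32: repeatedly truncate to bf16 and subtract the residual.
    parts, rem = [], x
    for i in range(n):
        p = rem.to(torch.bfloat16)
        if i != n - 1:
            rem = rem - p.to(torch.float32)
        parts.append(p.to(torch.float32))  # widen back; each part is bf16-exact
    return parts  # parts[0]=hi, parts[1]=mid, (parts[2]=lo when n == 3)

def dot_bf16xN(a, b, precision="bf16x3"):
    n = 2 if precision == "bf16x3" else 3
    hi, mid, lo = 0, 1, 2
    A, B = split_f32(a, n), split_f32(b, n)
    acc = torch.zeros(a.shape[0], b.shape[1])
    if precision == "bf16x6":
        acc = acc + A[mid] @ B[mid]
        acc = acc + A[lo] @ B[hi]
        acc = acc + A[hi] @ B[lo]
    acc = acc + A[mid] @ B[hi]
    acc = acc + A[hi] @ B[mid]
    acc = torch.where(torch.isnan(acc), torch.zeros_like(acc), acc)
    acc = acc + A[hi] @ B[hi]
    return acc

a, b = torch.randn(64, 64), torch.randn(64, 64)
for p in ("bf16x3", "bf16x6"):
    err = (dot_bf16xN(a, b, p) - a @ b).abs().max()
    print(p, float(err))
```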

python/src/ir.cc

Lines changed: 2 additions & 0 deletions
@@ -308,6 +308,8 @@ void init_triton_ir(py::module &&m) {
       .value("TF32", InputPrecision::TF32)
       .value("TF32x3", InputPrecision::TF32x3)
       .value("IEEE", InputPrecision::IEEE)
+      .value("BF16x3", InputPrecision::BF16x3)
+      .value("BF16x6", InputPrecision::BF16x6)
       .export_values();
 
   py::enum_<ScaleDotElemType>(m, "ScaleDotElemTypeTY", py::module_local())

python/src/passes.cc

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
   ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul);
   ADD_PASS_WRAPPER_0("add_reorder_instructions",
                      createTritonGPUReorderInstructions);
-  ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC);
+  ADD_PASS_OPTION_WRAPPER_1("add_f32_dot_tc", createTritonGPUF32DotTC, bool);
   ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands",
                             createTritonGPUOptimizeDotOperands, bool);
   ADD_PASS_WRAPPER_0("add_remove_layout_conversions",

python/test/unit/language/test_core.py

Lines changed: 11 additions & 3 deletions
@@ -3084,7 +3084,7 @@ def get_test_dot_base_cases():
     return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None)
             for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
             for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
-            for input_precision in ['tf32', 'tf32x3', 'ieee']
+            for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16x3', 'bf16x6']
             for in_dtype, out_dtype in [('float16', 'float16'), ('float16',
                                                                  'float32'), ('float32',
                                                                               'float32'), ('float64', 'float64')]
@@ -3209,6 +3209,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
     if is_interpreter():
         if in_dtype == 'bfloat16':
             pytest.skip("bfloat16 is not supported in the interpreter")
+        if input_precision == "bf16x3" or input_precision == "bf16x6":
+            pytest.skip(f"input_precision {input_precision} is not supported in the interpreter")
     else:
         if not is_hip() and K < 16:
             pytest.skip("small dots are supported only on HIP at the moment")
@@ -3238,7 +3240,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
                 pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12")
             if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3():
                 pytest.skip(f"{in_dtype} only supported on CDNA3")
-            if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())):
+            if not ((input_precision in ("bf16x3", "bf16x6")) or (input_precision == "ieee") or
+                    (input_precision == "tf32" and is_hip_cdna3())):
                 pytest.skip(f"{input_precision} not supported on HIP")
             if kpack == 2 and in_dtype == 'int8' and K < 64:
                 pytest.skip("kpack too large for K")
@@ -3426,7 +3429,12 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
 
     if in_dtype == 'float32' and input_precision != "ieee":
         if is_tcgen5:
-            assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx)
+            if input_precision in ("bf16x3", "bf16x6"):
+                assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
+            else:
+                assert re.search(r'tcgen05.mma.cta_group::1.kind::tf32', ptx)
+        elif input_precision in ("bf16x3", "bf16x6"):
+            assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx)
         else:
             assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx)
     elif in_dtype == 'float16' and out_dtype == tl.float32:

python/triton/language/semantic.py

Lines changed: 4 additions & 0 deletions
@@ -1467,6 +1467,10 @@ def _str_to_dot_input_precision(self, input_precision):
         input_precision = input_precision.upper()
         if input_precision == "TF32X3":
             input_precision = "TF32x3"
+        if input_precision == "BF16X3":
+            input_precision = "BF16x3"
+        if input_precision == "BF16X6":
+            input_precision = "BF16x6"
         return getattr(ir.INPUT_PRECISION, input_precision)
 
     def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
