Skip to content

Commit 22ac447

Browse files
authored
[FRONTEND][BACKEND] Plumb the fast_math attribute from the dot_scaled frontend through to LLVM codegen. When it is set, the NaN-scale check (scale == 0xff) is skipped during mxfp upcasting. (triton-lang#5582)
1 parent 199fd8a commit 22ac447

File tree

12 files changed

+65
-38
lines changed

12 files changed

+65
-38
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,8 @@ SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
480480
ArrayRef<Value> values);
481481

482482
// Scale a mxfp4 value by a given scale.
483-
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale);
483+
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
484+
bool fastMath);
484485

485486
} // namespace LLVM
486487

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,8 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
690690
Optional<RankedTensorOf<[I8]>>:$lhs_scale,
691691
Optional<RankedTensorOf<[I8]>>:$rhs_scale,
692692
TT_ScaleDotElemTypeAttr:$lhs_type,
693-
TT_ScaleDotElemTypeAttr:$rhs_type
693+
TT_ScaleDotElemTypeAttr:$rhs_type,
694+
BoolAttr:$fastMath
694695
);
695696

696697
let results = (outs TT_FloatTensor:$d);

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,13 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> {
292292
Compute the bf16 encoded in the given mxfp number as per
293293
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
294294
}];
295-
let arguments = (ins
296-
TT_Tensor:$src,
297-
TT_Tensor:$scale,
298-
TT_ScaleDotElemTypeAttr:$fp_type);
295+
let arguments = (
296+
ins
297+
TT_Tensor:$src,
298+
TT_Tensor:$scale,
299+
TT_ScaleDotElemTypeAttr:$fp_type,
300+
BoolAttr:$fastMath
301+
);
299302
let results = (outs TT_Tensor:$result);
300303

301304
let assemblyFormat = [{

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -904,13 +904,15 @@ SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
904904
return results;
905905
}
906906

907-
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v,
908-
Value scale) {
907+
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
908+
bool fastMath) {
909909
Value vBf16 = bitcast(v, bf16_ty);
910-
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
911-
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
912910
Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
913911
Value scaledBf16 = fmul(vBf16, scaleBf16);
912+
if (fastMath)
913+
return scaledBf16;
914+
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
915+
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
914916
// Account for NaN in the scale as per the mxfp specification.
915917
return select(scaleIsNan, nanBf16, scaledBf16);
916918
};

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ class DecomposeScaledBlocked
398398
auto scale = scaledDotOp.getLhsScale();
399399
auto aType = scaledDotOp.getLhsType();
400400
auto bType = scaledDotOp.getRhsType();
401+
bool fastMath = scaledDotOp.getFastMath();
401402

402403
auto rank = oldRetType.getShape().size();
403404
if (rank != 2)
@@ -510,15 +511,17 @@ class DecomposeScaledBlocked
510511
newScaleEncoding = LinearEncodingAttr::get(ctx, std::move(newLL));
511512
}
512513

513-
a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding);
514+
a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding,
515+
fastMath);
514516

515517
Operation *newDot = nullptr;
516518
if (versionMajor == 2) {
517519
// Upcast B operand
518520
assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4");
519521
auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth);
520522
b = createArg(rewriter, b, 1, bType, newBEncoding,
521-
/*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt);
523+
/*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt,
524+
fastMath);
522525
newDot = rewriter.create<DotOp>(scaledDotOp.getLoc(), newRetType, a, b,
523526
newAcc);
524527
} else {
@@ -541,7 +544,7 @@ class DecomposeScaledBlocked
541544
createArg(mlir::PatternRewriter &rewriter, TypedValue<RankedTensorType> v,
542545
int idx, ScaleDotElemType type, std::optional<Attribute> vEncoding,
543546
std::optional<TypedValue<RankedTensorType>> opt_scale,
544-
std::optional<Attribute> scaleEncoding) const {
547+
std::optional<Attribute> scaleEncoding, bool fastMath) const {
545548
auto ctx = rewriter.getContext();
546549
// Create a new tensor with a given encoding or remove the encoding
547550
auto maybeWithEncoding =
@@ -576,7 +579,7 @@ class DecomposeScaledBlocked
576579
auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType(
577580
ret, type, Builder(v.getContext()).getBF16Type());
578581
ret = rewriter.create<triton::gpu::UpcastMXFPOp>(v.getLoc(), retTy, ret,
579-
scale, type);
582+
scale, type, fastMath);
580583
}
581584
return ret;
582585
}
@@ -589,6 +592,7 @@ class DecomposeScaledBlocked
589592
auto scale = scaledDotOp.getLhsScale();
590593
auto aType = scaledDotOp.getLhsType();
591594
auto bType = scaledDotOp.getRhsType();
595+
bool fastMath = scaledDotOp.getFastMath();
592596

593597
// create a DotOp to be passed in to getMMAVersionSafe
594598
// We don't pass encodings as we just want to get the type and shape
@@ -597,15 +601,16 @@ class DecomposeScaledBlocked
597601
// end up in the graph
598602
RankedTensorType aTType =
599603
createArg(rewriter, a, 0, aType, /*vEncoding=*/std::nullopt, scale,
600-
/*scaleEncoding=*/std::nullopt)
604+
/*scaleEncoding=*/std::nullopt, fastMath)
601605
.getType();
602606
auto aTypeNoEnc =
603607
RankedTensorType::get(aTType.getShape(), aTType.getElementType());
604608
a = rewriter.create<ConvertLayoutOp>(scaledDotOp.getLoc(), aTypeNoEnc, a);
605609

606610
RankedTensorType bTType =
607611
createArg(rewriter, b, 1, bType, /*vEncoding=*/std::nullopt,
608-
/*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt)
612+
/*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt,
613+
fastMath)
609614
.getType();
610615
auto bTypeNoEnc =
611616
RankedTensorType::get(bTType.getShape(), bTType.getElementType());
@@ -752,7 +757,7 @@ static Operation *transposeDotOp(DotScaledOp dotOp) {
752757
Value result = builder.create<DotScaledOp>(
753758
dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed,
754759
cTransposed, dotOp.getRhsScale(), dotOp.getLhsScale(), dotOp.getRhsType(),
755-
dotOp.getLhsType());
760+
dotOp.getLhsType(), dotOp.getFastMath());
756761
Operation *transposedResult =
757762
builder.create<TransOp>(result.getLoc(), result, transOrder);
758763
dotOp.replaceAllUsesWith(transposedResult);

python/src/ir.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,10 +1513,12 @@ void init_triton_ir(py::module &&m) {
15131513
std::optional<mlir::Value> &lhs_scale,
15141514
ScaleDotElemType lhs_format, mlir::Value &rhs,
15151515
std::optional<mlir::Value> &rhs_scale,
1516-
ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value {
1517-
return self.create<DotScaledOp>(
1518-
c.getType(), lhs, rhs, c, lhs_scale.value_or(Value()),
1519-
rhs_scale.value_or(Value()), lhs_format, rhs_format);
1516+
ScaleDotElemType rhs_format, bool fast_math,
1517+
mlir::Value &c) -> mlir::Value {
1518+
return self.create<DotScaledOp>(c.getType(), lhs, rhs, c,
1519+
lhs_scale.value_or(Value()),
1520+
rhs_scale.value_or(Value()),
1521+
lhs_format, rhs_format, fast_math);
15201522
})
15211523
.def("create_floor",
15221524
[](TritonOpBuilder &self, Value &val) -> Value {

python/triton/language/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,7 +1733,8 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
17331733

17341734

17351735
@builtin
1736-
def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None):
1736+
def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, fast_math=False, acc=None, out_dtype=float32,
1737+
_builder=None):
17371738
"""
17381739
Returns the matrix product of two blocks in microscaling format.
17391740
@@ -1763,7 +1764,8 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
17631764
"""
17641765
out_dtype = _constexpr_to_value(out_dtype)
17651766
assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
1766-
return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, _builder)
1767+
return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, fast_math, acc, out_dtype,
1768+
_builder)
17671769

17681770

17691771
# -----------------------

python/triton/language/semantic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,7 +1562,8 @@ def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder):
15621562

15631563

15641564
def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
1565-
rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
1565+
rhs_format: str, fast_math: bool, acc: tl.tensor | None, out_dtype: tl.dtype,
1566+
builder: ir.builder) -> tl.tensor:
15661567
assert lhs.type.is_block() and rhs.type.is_block()
15671568
#TODO: validate types.
15681569
lhs_rank = len(lhs.shape)
@@ -1601,7 +1602,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te
16011602
lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
16021603
return tl.tensor(
16031604
builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
1604-
rhs_format_enum, acc_handle), ret_ty)
1605+
rhs_format_enum, fast_math, acc_handle), ret_ty)
16051606

16061607

16071608
# ===----------------------------------------------------------------------===//

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2110,7 +2110,15 @@ tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #m
21102110
// CHECK-COUNT-4: llvm.inline_asm
21112111
// CHECK-COUNT-2: nvvm.shfl.sync
21122112
// CHECK-COUNT-32: llvm.fmul
2113-
%0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
2113+
// CHECK: llvm.icmp
2114+
// CHECK: llvm.select
2115+
%0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 {fastMath = false} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
2116+
// CHECK-COUNT-4: llvm.inline_asm
2117+
// CHECK-COUNT-2: nvvm.shfl.sync
2118+
// CHECK-COUNT-32: llvm.fmul
2119+
// CHECK-NOT: llvm.icmp
2120+
// CHECK-NOT: llvm.select
2121+
%1 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 {fastMath = true} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
21142122
tt.return
21152123
}
21162124

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,10 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
204204
%b_bf16: tensor<64x128xbf16, #blocked>
205205
) -> tensor<128x128xf32, #blocked> {
206206
// CHECK: ttg.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, {{.*}}>
207-
// CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>>
208-
// CHECK: ttng.warp_group_dot
207+
// CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 {fastMath = false} : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>>
208+
// CHECK-NEXT: ttng.warp_group_dot {{.*}}
209209
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
210-
%result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
210+
%result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 {fastMath = false} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
211211
tt.return %result : tensor<128x128xf32, #blocked>
212212
}
213213

@@ -220,9 +220,9 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
220220
) -> tensor<128x128xf32, #blocked> {
221221
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
222222
// CHECK: ttg.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]>
223-
// CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>>
223+
// CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 {fastMath = true} : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>>
224224
// CHECK: tt.dot
225-
%result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked>
225+
%result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 {fastMath = true} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked>
226226
tt.return %result : tensor<128x128xf32, #blocked>
227227
}
228228
}
@@ -246,7 +246,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
246246
%0 = scf.for %arg4 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<128x32xf32, #blocked1>) : i32 {
247247
// CHECK-DAG: tt.trans %{{.*}} {order = array<i32: 1, 0>} : tensor<128x64xf8E4M3FN, #{{.*}}> -> tensor<64x128xf8E4M3FN, #{{.*}}>
248248
// CHECK-DAG: tt.trans %a{{.*}} {order = array<i32: 1, 0>} : tensor<32x32xi8, #{{.*}}> -> tensor<32x32xi8, #{{.*}}>
249-
%3 = tt.dot_scaled %arg0, %arg1 scale %arg2, %arg5 lhs = e4m3 rhs = e2m1 : tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1>
249+
%3 = tt.dot_scaled %arg0, %arg1 scale %arg2, %arg5 lhs = e4m3 rhs = e2m1 {fastMath = false}: tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1>
250250
// CHECK: tt.dot
251251
// CHECK-NOT: tt.trans
252252
// CHECK: scf.yield

0 commit comments

Comments
 (0)