
Commit b1064f2

antiagainstguacamoleo authored and committed
[AMD] Add support for scaled_dot(mxfp4, -) (triton-lang#5034)
This commit adds support for an mxfp4-typed A tensor for scaled dot in the AMD backend. We moved the `convertMxfp4x2ToBf16x2` implementation from the NVIDIA side to a common path for reuse.
1 parent 299fed6 commit b1064f2

File tree

7 files changed: +100 -65 lines


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 9 additions & 0 deletions
@@ -391,6 +391,15 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
   Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
   return base;
 }
+
+// -----------------------------------------------------------------------
+// MXFP utilities
+// -----------------------------------------------------------------------
+
+// Convert one int8, which contains 2 packed mxfp4 values, into 2 bf16
+// standalone values and returns them as a pair of (high 4 bits, low 4 bits).
+std::pair<Value, Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter,
+                                               Location loc, Value v);
 } // namespace LLVM

 /* ------------------------------------ */

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 27 additions & 0 deletions
@@ -862,5 +862,32 @@ SmallVector<Value> getWrappedMultiDimOffset(
   return multiDimOffsetWrapped;
 }

+std::pair<Value, Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter,
+                                               Location loc, Value v) {
+  auto em0 = and_(v, i8_val(0x70));
+  auto em1 = and_(v, i8_val(0x7));
+  Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)),
+                 shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));
+  Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)),
+                 shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12)));
+
+  // Three cases:
+  // 1) x is normal and non-zero: Correct bias
+  v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)),
+              add(v0, i16_val((127 - 1) << 7)), v0);
+  v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)),
+              add(v1, i16_val((127 - 1) << 7)), v1);
+
+  // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
+  // bf16
+  v0 = select(icmp_eq(em0, i8_val(0x10)),
+              or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0);
+  v1 = select(icmp_eq(em1, i8_val(0x1)),
+              or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1);
+  // 3) x is zero, nothing to do
+
+  return {v0, v1};
+}
+
 } // namespace LLVM
 } // namespace mlir
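For reference, the bit manipulation the new helper performs on one packed byte can be sketched in Python as below (not part of the diff; the function name and plain-integer form are ours):

def mxfp4x2_to_bf16x2_bits(v: int) -> tuple[int, int]:
    # One byte holds two e2m1 values; return two bf16 bit patterns,
    # first for the high 4 bits, then for the low 4 bits.
    em0 = v & 0x70                       # exponent + mantissa of the high nibble
    em1 = v & 0x07                       # exponent + mantissa of the low nibble
    v0 = (em0 << 2) | ((v & 0x80) << 8)  # move sign/exp/mantissa into bf16 bit slots
    v1 = (em1 << 6) | ((v & 0x08) << 12)
    if em0 & 0x60:                       # normal: rebias exponent (bf16 bias 127 vs. fp4 bias 1)
        v0 += (127 - 1) << 7
    if em1 & 0x06:
        v1 += (127 - 1) << 7
    if em0 == 0x10:                      # subnormal 0bs001: exactly +-0.5 in bf16 (0x3F00)
        v0 = 0x3F00 | (v0 & 0x8000)
    if em1 == 0x01:
        v1 = 0x3F00 | (v1 & 0x8000)
    return v0, v1                        # a zero input stays zero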

python/test/unit/language/test_core.py

Lines changed: 9 additions & 2 deletions
@@ -29,6 +29,7 @@
     is_cuda,
     is_interpreter,
     is_hip,
+    is_hip_mi200,
     get_arch,
     torch_float8_dtypes,
     torch_dtypes,
@@ -3354,7 +3355,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
-        if type_a != "e5m2" or (type_b != "e5m2" and type_b != "bf16"):
+        if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]):
            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
        if mma == 16 and K == 64:
            pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
@@ -3530,7 +3531,13 @@ def make_finite(x, dtype):

    z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)

-   torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)
+   # Bigger tolerance for AMD MI200 devices.
+   # MI200 devices use reduced precision fp16 and bf16 and flush input and output denormal values
+   # to zero. Detailed info is at:
+   # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
+   atol = 2e-4 if is_hip_mi200() else 1e-5
+   rtol = 2e-2 if is_hip_mi200() else 1e-2
+   torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol)

    # make sure ld/st are vectorized
    if is_cuda():

python/triton/_internal_testing.py

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,11 @@ def is_hip():
     return False if target is None else target.backend == "hip"


+def is_hip_mi200():
+    target = get_current_target()
+    return target.backend == 'hip' and target.arch == 'gfx90a'
+
+
 def get_arch():
     target = get_current_target()
     return "" if target is None else str(target.arch)

third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 23 additions & 3 deletions
@@ -21,10 +21,11 @@ namespace {

 Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v,
                     Value scale) {
+  Value vBf16 = bitcast(v, bf16_ty);
   Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
   Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
   Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
-  Value scaledBf16 = fmul(v, scaleBf16);
+  Value scaledBf16 = fmul(vBf16, scaleBf16);
   // Account for NaN in the scale as per the mxfp specification.
   return select(scaleIsNan, nanBf16, scaledBf16);
 };
@@ -43,7 +44,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
   matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto fpType = op.getFpType();
-    if (!(fpType == ScaleDotElemType::E4M3 || fpType == ScaleDotElemType::E5M2))
+    bool isPacked = fpType == ScaleDotElemType::E2M1;
+    if (!(isPacked || fpType == ScaleDotElemType::E4M3 ||
+          fpType == ScaleDotElemType::E5M2))
       return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases");

     Location loc = op.getLoc();
@@ -56,7 +59,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     // warp. MXFP spec mandates 1 scale value for every 32 onsecutive values
     // along the K dimension. So in total each thread should read 32x main
     // element values.
-    if (xVals.size() != scaleVals.size() * 32)
+    if (xVals.size() != scaleVals.size() * (isPacked ? 16 : 32))
       return rewriter.notifyMatchFailure(op, "unsupported problem size");

     auto dotEncoding =
@@ -79,6 +82,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     Value warpId = udiv(tid, warpSize);
     Value laneId = urem(tid, warpSize);

+    if (isPacked)
+      xVals = unpackFP4Elements(loc, rewriter, xVals);
+
     // Given that MFMA layout for the A tensor arranges thread in a column-major
     // manner, for the current tid, it's at row (tid % mDim). When we set up
     // blocked layout for the A scale tensor, we made sure that it has a
@@ -136,6 +142,20 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     rewriter.replaceOp(op, result);
     return success();
   }
+
+private:
+  SmallVector<Value> unpackFP4Elements(Location loc, RewriterBase &rewriter,
+                                       ArrayRef<Value> packed) const {
+    // Split every fp4x2 into 2 bf16 values.
+    llvm::SmallVector<Value> unpacked;
+    unpacked.reserve(packed.size() * 2);
+    for (Value v : packed) {
+      auto [e0, e1] = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, v);
+      unpacked.push_back(e0);
+      unpacked.push_back(e1);
+    }
+    return unpacked;
+  }
 };
 } // anonymous namespace
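Note (our reading of the hunks above): in the packed e2m1 case each thread holds scaleVals.size() * 16 int8 values, and unpackFP4Elements expands every int8 into two bf16 values, so the 32-elements-per-scale relationship mandated by the MXFP spec still holds after unpacking (16 * 2 = 32).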

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 18 additions & 26 deletions
@@ -507,9 +507,10 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     ScaleDotElemType aElemType = dotOp.getLhsType();
     ScaleDotElemType bElemType = dotOp.getRhsType();

-    if (!(aElemType == ScaleDotElemType::E4M3 ||
+    if (!(aElemType == ScaleDotElemType::E2M1 ||
+          aElemType == ScaleDotElemType::E4M3 ||
           aElemType == ScaleDotElemType::E5M2))
-      return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8 LHS");
+      return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8/mxfp4 LHS");
     if (!(bElemType == ScaleDotElemType::E4M3 ||
           bElemType == ScaleDotElemType::E5M2 ||
           bElemType == ScaleDotElemType::BF16))
@@ -532,7 +533,16 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     unsigned nDim = mfmaInstr.value().getNDim();
     unsigned kDim = mfmaInstr.value().getKDim();
     unsigned kBase = mfmaInstr.value().getKBase();
-    unsigned kWdith = kBase *= kPack;
+
+    // If A tensor contains mxfp4, we pack every two values into one int8 value
+    // there. For such cases, we have different initial kWidth for LHS and RHS,
+    // which will be "fixed" later by using upcast_mxfp to convert LHS to
+    // unpacked values. For such packed cases, we cannot support flexible kPack
+    // choices from the developer--it just does not apply here. So mandate the
+    // choice here.
+    bool isPacked = aElemType == ScaleDotElemType::E2M1;
+    unsigned kWdiths[] = {isPacked ? 4 : kBase * kPack,
+                          isPacked ? 8 : kBase * kPack};

     // For A tensor, 32 consecutive elements along K dim share the same scale.
     // We'd like to keep the scale values together with the base values in the
@@ -553,38 +563,20 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(
         dotOp.getC().getLoc(), newRetType, dotOp.getC());

-    // OCP mxfp8 requires implementations to follow OCP fp8 elements. We are
-    // doing software emulation using bf16 here, so we map to OCP fp8 f8E4M3FN
-    // and f8E5M2.
-    auto enumToType = [&rewriter](ScaleDotElemType type) {
-      switch (type) {
-      case ScaleDotElemType::E4M3:
-        return rewriter.getFloat8E4M3FNType();
-      case ScaleDotElemType::E5M2:
-        return rewriter.getFloat8E5M2Type();
-      default:
-        llvm_unreachable("unexpected fp type");
-      }
-    };
-
     auto toMMABf16 = [&](TensorValue v, int idx,
                          ScaleDotElemType type) -> TensorValue {
-      assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3 ||
-             type == ScaleDotElemType::BF16);
-
       auto vType = v.getType();
       auto newVEncoding = DotOperandEncodingAttr::get(
-          ctx, idx, newRetType.getEncoding(), kWdith);
+          ctx, idx, newRetType.getEncoding(), kWdiths[idx]);
       auto newVType = RankedTensorType::get(
           vType.getShape(), vType.getElementType(), newVEncoding);
       v = rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
       if (type == ScaleDotElemType::BF16)
         return v;
-
-      auto vTypeFp8 = RankedTensorType::get(vType.getShape(), enumToType(type),
-                                            newVEncoding);
-      v = cast<TensorValue>(
-          rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
+      // Don't need to convert int8 holding mxfp4 for A--the upcast_mxfp op can
+      // take int8 tensor as input.
+      if (idx == 0 && type == ScaleDotElemType::E2M1)
+        return v;

       auto vTypeBf16 = RankedTensorType::get(
           vType.getShape(), rewriter.getBF16Type(), newVEncoding);
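In other words (our reading of the kWidth comment above): with mxfp4 packed two-per-int8, kWidth = 4 on the A side means 4 int8 values, i.e. 4 * 2 = 8 fp4 elements per thread, which lines up with the kWidth of 8 chosen for the unpacked B operand once upcast_mxfp expands A to bf16.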

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 9 additions & 34 deletions
@@ -1,6 +1,7 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"

 #include "PatternTritonGPUOpToLLVM.h"
@@ -12,7 +13,6 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"
 #include <array>

 using namespace mlir;
@@ -30,42 +30,17 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
       : ConvertOpToLLVMPattern<UpcastMXFPOp>(typeConverter, benefit),
         targetInfo(targetInfo) {}

-  llvm::SmallVector<Value>
-  unpackFP4Elements(Location loc, ConversionPatternRewriter &rewriter,
-                    const llvm::SmallVector<Value> &vals, Value laneId) const {
-    auto fp4x2ToBf16x2 = [&loc, &rewriter](Value v) -> Value {
-      auto em0 = and_(v, i8_val(0x70));
-      auto em1 = and_(v, i8_val(0x7));
-      Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)),
-                     shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));
-      Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)),
-                     shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12)));
-
-      // Three cases:
-      // 1) x is normal and non-zero: Correct bias
-      v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)),
-                  add(v0, i16_val((127 - 1) << 7)), v0);
-      v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)),
-                  add(v1, i16_val((127 - 1) << 7)), v1);
-
-      // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
-      // bf16
-      v0 = select(icmp_eq(em0, i8_val(0x10)),
-                  or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0);
-      v1 = select(icmp_eq(em1, i8_val(0x1)),
-                  or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1);
-      // 3) x is zero, nothing to do
-
-      // Swap as they come packed in big endian
-      return or_(zext(i32_ty, v0), shl(zext(i32_ty, v1), i32_val(16)));
-    };
+  llvm::SmallVector<Value> unpackFP4Elements(Location loc,
+                                             RewriterBase &rewriter,
+                                             ArrayRef<Value> vals) const {

-    auto fp4x8ToBf16x2 = [&loc, &rewriter, &fp4x2ToBf16x2](
-                             Value v) -> llvm::SmallVector<Value, 4> {
+    auto fp4x8ToBf16x2 = [&loc, &rewriter](Value v) {
       llvm::SmallVector<Value, 4> results(4);
       for (int i = 0; i < 4; ++i) {
         auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i)));
-        results[i] = fp4x2ToBf16x2(v_i);
+        auto [e0, e1] = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, v_i);
+        // Swap as they come packed in big endian
+        results[i] = or_(zext(i32_ty, e0), shl(zext(i32_ty, e1), i32_val(16)));
       }
       return results;
     };
@@ -104,7 +79,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     Value laneId = urem(tid, warpSize);

     if (fpType == ScaleDotElemType::E2M1) {
-      xVals = unpackFP4Elements(loc, rewriter, xVals, laneId);
+      xVals = unpackFP4Elements(loc, rewriter, xVals);
     }

     auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value {
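The NVIDIA path then repacks each bf16 pair into one i32, with the swap noted in the comment above. A minimal Python sketch of that repacking (helper name is ours, not part of the commit):

def repack_bf16x2_to_i32(e0: int, e1: int) -> int:
    # e0 (from the high nibble) goes into the low 16 bits, e1 (from the low
    # nibble) into the high 16 bits, matching the big-endian fp4 packing.
    return ((e1 & 0xFFFF) << 16) | (e0 & 0xFFFF)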
