Skip to content

Commit 3121ad5

Browse files
authored
[AMD] Enable packed Bf8/Fp8->Bf16 conversions for gfx950 (triton-lang#6291)
Support Bf8/Fp8->Bf16 conversions with ROCDL wrappers `rocdl.cvt.scalef32.pk.bf16.*` in gfx950.
1 parent e489d68 commit 3121ad5

File tree

2 files changed

+60
-23
lines changed

2 files changed

+60
-23
lines changed

test/Conversion/amd/fp_to_fp.mlir

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
7878
%0 = tt.fp_to_fp %arg0 : tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
7979

8080
// CHECK-GFX950-COUNT-4: rocdl.cvt.scalef32.pk.f16.bf8
81-
%2 = tt.fp_to_fp %arg0 : tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
81+
%1 = tt.fp_to_fp %arg0 : tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
82+
83+
// CHECK-GFX950-COUNT-4: rocdl.cvt.scalef32.pk.bf16.bf8
84+
%2 = tt.fp_to_fp %arg0 : tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
8285

8386
// CHECK-GFX950-COUNT-4: rocdl.cvt.scalef32.pk.f32.fp8
8487
%3 = tt.fp_to_fp %arg1 : tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
8588

8689
// CHECK-GFX950-COUNT-4: rocdl.cvt.scalef32.pk.f16.fp8
87-
%5 = tt.fp_to_fp %arg1 : tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
90+
%4 = tt.fp_to_fp %arg1 : tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
91+
92+
// CHECK-GFX950-COUNT-4: rocdl.cvt.scalef32.pk.bf16.fp8
93+
%5 = tt.fp_to_fp %arg1 : tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
8894
tt.return
8995
}
9096
}

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
3131
auto b = TritonLLVMOpBuilder(loc, rewriter);
3232
auto fp8x4VecTy = vec_ty(i8_ty, 4);
3333
Value fp8x4Vec = b.undef(fp8x4VecTy);
34-
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v0, b.i32_val(0));
35-
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v1, b.i32_val(1));
34+
auto idx0 = b.i32_val(0);
35+
auto idx1 = b.i32_val(1);
36+
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v0, idx0);
37+
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v1, idx1);
3638
auto i32v = b.bitcast(fp8x4Vec, i32_ty);
3739

3840
auto resType = i32_ty;
@@ -41,18 +43,22 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
4143
std::is_same_v<convertOp, ROCDL::CvtScaleF32PkF32Bf8Op>) {
4244
resType = i64_ty;
4345
dstType = f32_ty;
44-
} else {
46+
} else if constexpr (std::is_same_v<convertOp,
47+
ROCDL::CvtScaleF32PkF16Fp8Op> ||
48+
std::is_same_v<convertOp,
49+
ROCDL::CvtScaleF32PkF16Bf8Op>) {
4550
resType = i32_ty;
4651
dstType = f16_ty;
52+
} else {
53+
resType = i32_ty;
54+
dstType = bf16_ty;
4755
}
4856
Value scale = b.f32_val(1);
4957
Value select = b.false_val();
5058
auto result = rewriter.create<convertOp>(loc, resType, i32v, scale, select);
5159
auto retVecTy = vec_ty(dstType, 2);
5260
auto retVec = b.bitcast(result, retVecTy);
5361
SmallVector<Value> ret(2);
54-
auto idx0 = b.i32_val(0);
55-
auto idx1 = b.i32_val(1);
5662
ret[0] = b.extract_element(dstType, retVec, idx0);
5763
ret[1] = b.extract_element(dstType, retVec, idx1);
5864
return ret;
@@ -77,8 +83,10 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
7783
} else {
7884
Type v2F16Ty = vec_ty(v0.getType(), 2);
7985
Value srcVec = b.undef(v2F16Ty);
80-
srcVec = b.insert_element(v2F16Ty, srcVec, v0, b.i32_val(0));
81-
srcVec = b.insert_element(v2F16Ty, srcVec, v1, b.i32_val(1));
86+
auto idx0 = b.i32_val(0);
87+
auto idx1 = b.i32_val(1);
88+
srcVec = b.insert_element(v2F16Ty, srcVec, v0, idx0);
89+
srcVec = b.insert_element(v2F16Ty, srcVec, v1, idx1);
8290
result = rewriter.create<convertOp>(loc, v2I16Ty, v2I16Vec, srcVec, scale,
8391
select);
8492
}
@@ -698,9 +706,10 @@ ConverterT Fp8E5M2FNUZ_to_Fp16(AMD::ISAFamily isaFamily) {
698706
: Fp8E5M2FNUZ_to_Fp16_SW;
699707
}
700708

701-
static SmallVector<Value> Fp8E5M2_to_Bf16(Location loc,
702-
ConversionPatternRewriter &rewriter,
703-
const SmallVector<Value> &v) {
709+
// OCP Bf8 -> Bf16
710+
static SmallVector<Value>
711+
Fp8E5M2_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
712+
const SmallVector<Value> &v) {
704713
auto b = TritonLLVMOpBuilder(loc, rewriter);
705714
auto fp8x4VecTy = vec_ty(i8_ty, 4);
706715
Value a0 = b.undef(fp8x4VecTy);
@@ -761,6 +770,19 @@ static SmallVector<Value> Fp8E5M2_to_Bf16(Location loc,
761770
b.extract_element(bf16_ty, out1, b.i32_val(1))};
762771
}
763772

773+
static SmallVector<Value>
774+
Fp8E5M2_to_Bf16_HW(Location loc, ConversionPatternRewriter &rewriter,
775+
const SmallVector<Value> &v) {
776+
assert(v.size() == 2);
777+
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Bf8Op>(loc, rewriter,
778+
v[0], v[1]);
779+
}
780+
781+
ConverterT Fp8E5M2_to_Bf16(AMD::ISAFamily isaFamily) {
782+
return isaFamily == AMD::ISAFamily::CDNA4 ? Fp8E5M2_to_Bf16_HW
783+
: Fp8E5M2_to_Bf16_SW;
784+
}
785+
764786
// Bf16 -> OCP Bf8
765787
static SmallVector<Value>
766788
Bf16_to_Fp8E5M2_SW(Location loc, ConversionPatternRewriter &rewriter,
@@ -869,9 +891,9 @@ static SmallVector<Value> Bf16_to_Fp8E4M3FN(Location loc,
869891
}
870892

871893
// fp8e4m3fn to bf16
872-
static SmallVector<Value> Fp8E4M3FN_to_Bf16(Location loc,
873-
ConversionPatternRewriter &rewriter,
874-
const SmallVector<Value> &v) {
894+
static SmallVector<Value>
895+
Fp8E4M3FN_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
896+
const SmallVector<Value> &v) {
875897
auto b = TritonLLVMOpBuilder(loc, rewriter);
876898
auto fp8x4VecTy = vec_ty(i8_ty, 4);
877899
Value a0 = b.undef(fp8x4VecTy);
@@ -904,6 +926,19 @@ static SmallVector<Value> Fp8E4M3FN_to_Bf16(Location loc,
904926
b.extract_element(bf16_ty, out0, b.i32_val(1))};
905927
}
906928

929+
static SmallVector<Value>
930+
Fp8E4M3FN_to_Bf16_HW(Location loc, ConversionPatternRewriter &rewriter,
931+
const SmallVector<Value> &v) {
932+
assert(v.size() == 2);
933+
return cvtScalePkUpcastFromFp8<ROCDL::CvtScaleF32PkBf16Fp8Op>(loc, rewriter,
934+
v[0], v[1]);
935+
}
936+
937+
ConverterT Fp8E4M3FN_to_Bf16(AMD::ISAFamily isaFamily) {
938+
return isaFamily == AMD::ISAFamily::CDNA4 ? Fp8E4M3FN_to_Bf16_HW
939+
: Fp8E4M3FN_to_Bf16_SW;
940+
}
941+
907942
// fp8e4m3fnuz to bf16
908943
static SmallVector<Value>
909944
Fp8E4M3FNUZ_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
@@ -1130,9 +1165,10 @@ struct FpToFpOpConversion
11301165
Fp16_to_Fp8E5M2_RTNE(isaFamily)},
11311166
{{F16TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp16_to_Fp8E5M2_RTZ},
11321167
// F8 -> BF16
1133-
{{F8E5M2TyID, BF16TyID, undefRounding}, Fp8E5M2_to_Bf16},
1168+
{{F8E5M2TyID, BF16TyID, undefRounding}, Fp8E5M2_to_Bf16(isaFamily)},
11341169
{{F8E5M2FNUZTyID, BF16TyID, undefRounding}, Fp8E5M2FNUZ_to_Bf16},
1135-
{{F8E4M3FNTyID, BF16TyID, undefRounding}, Fp8E4M3FN_to_Bf16},
1170+
{{F8E4M3FNTyID, BF16TyID, undefRounding},
1171+
Fp8E4M3FN_to_Bf16(isaFamily)},
11361172
{{F8E4M3FNUZTyID, BF16TyID, undefRounding}, Fp8E4M3FNUZ_to_Bf16},
11371173
// BF16 -> F8
11381174
{{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
@@ -1197,16 +1233,11 @@ struct FpToFpOpConversion
11971233
}
11981234

11991235
// numElements = 4 for conversions:
1200-
// ocp bf8->bf16, or
1201-
// ocp bf8->fp32/fp16 on non-CDNA4, or
1236+
// ocp bf8->fp32/fp16/bf16 on non-CDNA4, or
12021237
// fp32/bf16/fp16->ocp bf8 on non-CDNA4
12031238
// fp32/bf16/fp16->ocp bf8 (RTZ) on CDNA4
12041239
size_t numElements = 2;
12051240
if ((llvm::isa<Float8E5M2Type>(srcElementType) &&
1206-
llvm::isa<BFloat16Type>(dstElementType)) ||
1207-
(llvm::isa<Float8E5M2Type>(srcElementType) &&
1208-
(llvm::isa<Float16Type>(dstElementType) ||
1209-
llvm::isa<Float32Type>(dstElementType)) &&
12101241
isaFamily != AMD::ISAFamily::CDNA4) ||
12111242
(llvm::isa<Float8E5M2Type>(dstElementType) &&
12121243
isaFamily != AMD::ISAFamily::CDNA4) ||

0 commit comments

Comments
 (0)