Skip to content

Commit d0b3ad8

Browse files
authored
[AMD] Use rocdl op for fp32->fp16 RTZ conversion (#6425)
Replace inline assembly with ROCDL wrappers for Fp32->Fp16 conversions to enable potential backend optimizations.
1 parent 769a82b commit d0b3ad8

File tree

5 files changed

+69
-48
lines changed

5 files changed

+69
-48
lines changed

test/Conversion/amd/fp_to_fp.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
3131
tt.func @f32_to_f16(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
3232
// CHECK-COUNT-8: llvm.intr.experimental.constrained.fptrunc %{{.+}} tonearest ignore : f32 to f16
3333
%0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
34-
// CHECK-COUNT-8: llvm.inline_asm asm_dialect {{.*}}s_setreg_imm32_b32{{.+}}v_cvt_f16_f32{{.+}}s_setreg_imm32_b32{{.+}} : (f32) -> f16
35-
34+
// CHECK-COUNT-4: rocdl.cvt.pkrtz
3635
%1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
3736
tt.return
3837
}
@@ -117,6 +116,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
117116
tt.func @f8_rtz(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>,
118117
%arg1: tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
119118
// CHECK-GFX950-NOT: rocdl.cvt.scalef32.pk.f32.bf8
119+
// CHECK-GFX950-COUNT-4: rocdl.cvt.pkrtz
120120
%1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
121121
// CHECK-GFX950-NOT: rocdl.cvt.scalef32.pk.f16.bf8
122122
%2 = tt.fp_to_fp %arg1, rounding = rtz : tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,9 @@ static SmallVector<Value> Fp32_to_Fp8E4M3FN(Location loc,
372372
}
373373

374374
// Convert Fp32 to OCP Bf8 on CDNA4
375-
static SmallVector<Value> Fp32_to_Fp8E5M2(Location loc,
376-
ConversionPatternRewriter &rewriter,
377-
const SmallVector<Value> &v) {
375+
static SmallVector<Value>
376+
Fp32_to_Fp8E5M2_RTNE(Location loc, ConversionPatternRewriter &rewriter,
377+
const SmallVector<Value> &v) {
378378
assert(v.size() == 2);
379379
return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkBf8F32Op>(loc, rewriter,
380380
v[0], v[1]);
@@ -575,6 +575,43 @@ ConverterT Fp8E5M2_to_Fp16(AMD::ISAFamily isaFamily) {
575575
: Fp8E5M2_to_Fp16_SW;
576576
}
577577

578+
static SmallVector<Value>
579+
convertFp32ToFp16RTZ(Location loc, ConversionPatternRewriter &rewriter,
580+
const SmallVector<Value> &v) {
581+
assert(v.size() == 2);
582+
583+
auto b = TritonLLVMOpBuilder(loc, rewriter);
584+
Type v2f16Ty = vec_ty(f16_ty, 2);
585+
586+
Value result;
587+
result = rewriter.create<ROCDL::CvtPkRtz>(loc, v2f16Ty, v[0], v[1]);
588+
SmallVector<Value> ret(2);
589+
auto idx0 = b.i32_val(0);
590+
auto idx1 = b.i32_val(1);
591+
ret[0] = b.extract_element(f16_ty, result, idx0);
592+
ret[1] = b.extract_element(f16_ty, result, idx1);
593+
return ret;
594+
}
595+
596+
static SmallVector<Value>
597+
Fp32_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter,
598+
const SmallVector<Value> &v) {
599+
assert(v.size() == 4);
600+
SmallVector<Value> inVals(2);
601+
inVals[0] = v[0];
602+
inVals[1] = v[1];
603+
auto f16Vec = convertFp32ToFp16RTZ(loc, rewriter, inVals);
604+
SmallVector<Value> vec(4);
605+
vec[0] = f16Vec[0];
606+
vec[1] = f16Vec[1];
607+
inVals[0] = v[2];
608+
inVals[1] = v[3];
609+
f16Vec = convertFp32ToFp16RTZ(loc, rewriter, inVals);
610+
vec[2] = f16Vec[0];
611+
vec[3] = f16Vec[1];
612+
return Fp16_to_Fp8E5M2_RTZ(loc, rewriter, vec);
613+
}
614+
578615
static Value convertBf16ToFp32(Location loc,
579616
ConversionPatternRewriter &rewriter,
580617
const Value &v) {
@@ -670,8 +707,8 @@ Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
670707
cvtPkF8ToFp32<ROCDL::CvtPkF32Bf8Op>(loc, rewriter, v[0], v[1]);
671708

672709
// Convert fp32 to fp16
673-
ret[0] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[0], RoundingMode::RTNE);
674-
ret[1] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[1], RoundingMode::RTNE);
710+
ret[0] = LLVM::AMD::cvtFp32ToFp16RTNE(loc, rewriter, ret[0]);
711+
ret[1] = LLVM::AMD::cvtFp32ToFp16RTNE(loc, rewriter, ret[1]);
675712

676713
return ret;
677714
}
@@ -1006,8 +1043,8 @@ Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
10061043
cvtPkF8ToFp32<ROCDL::CvtPkF32Fp8Op>(loc, rewriter, v[0], v[1]);
10071044

10081045
// Convert fp32 to fp16
1009-
ret[0] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[0], RoundingMode::RTNE);
1010-
ret[1] = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, ret[1], RoundingMode::RTNE);
1046+
ret[0] = LLVM::AMD::cvtFp32ToFp16RTNE(loc, rewriter, ret[0]);
1047+
ret[1] = LLVM::AMD::cvtFp32ToFp16RTNE(loc, rewriter, ret[1]);
10111048

10121049
return ret;
10131050
}
@@ -1171,11 +1208,14 @@ struct FpToFpOpConversion
11711208
{{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
11721209
Fp32_to_Fp8E5M2FNUZ},
11731210
{{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
1174-
{{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2},
1211+
{{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2_RTNE},
1212+
{{F32TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp32_to_Fp8E5M2_RTZ},
11751213
{{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
11761214
{{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
11771215
{{F8E4M3FNTyID, F32TyID, undefRounding}, Fp8E4M3FN_to_Fp32},
11781216
{{F8E5M2TyID, F32TyID, undefRounding}, Fp8E5M2_to_Fp32},
1217+
// F32 -> F16 with RTZ
1218+
{{F32TyID, F16TyID, RoundingMode::RTZ}, convertFp32ToFp16RTZ},
11791219
};
11801220
std::tuple<TypeID, TypeID, RoundingMode> key = {
11811221
srcTy.getTypeID(), dstTy.getTypeID(),
@@ -1195,14 +1235,14 @@ struct FpToFpOpConversion
11951235
auto dstElementType = getElementType(op.getResult());
11961236

11971237
auto roundingMode = op.getRounding();
1198-
if (srcElementType.isF32() && dstElementType.isF16()) {
1238+
if (srcElementType.isF32() && dstElementType.isF16() &&
1239+
roundingMode.value() == RoundingMode::RTNE) {
11991240
assert(roundingMode.has_value() &&
12001241
"rounding mode must be specified for fp32->fp16 conversion");
12011242
SmallVector<Value> outVals;
12021243
outVals.reserve(operands[0].size());
12031244
for (Value v : operands[0]) {
1204-
outVals.push_back(
1205-
LLVM::AMD::cvtFp32ToFp16(loc, rewriter, v, roundingMode.value()));
1245+
outVals.push_back(LLVM::AMD::cvtFp32ToFp16RTNE(loc, rewriter, v));
12061246
}
12071247
return outVals;
12081248
}
@@ -1234,18 +1274,19 @@ struct FpToFpOpConversion
12341274
numElements = 4;
12351275
}
12361276

1237-
// f32->fp8/bf8, if not nanoo fp8/bf8 on CDNA3 or ocp fp8/bf8 on CDNA4, is
1238-
// done in two steps: f32->fp16 with rtne and fp16->fp8/bf8 with rtne
1277+
// f32->fp8/bf8 with rtne, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8
1278+
// on CDNA4, is done in two steps: f32->fp16 with rtne and fp16->fp8/bf8
1279+
// with rtne
12391280
bool useFP16IntermediateSrc =
1240-
srcElementType.isF32() &&
1281+
srcElementType.isF32() && !dstElementType.isF16() &&
1282+
roundingMode == RoundingMode::RTNE &&
12411283
!(isaFamily == AMD::ISAFamily::CDNA4 &&
1242-
(llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType)) &&
1243-
roundingMode == RoundingMode::RTNE) &&
1284+
(llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) &&
12441285
!(isaFamily == AMD::ISAFamily::CDNA3 &&
12451286
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
12461287

1247-
// fp8/bf8->f32, if not nanoo fp8/bf8 on CDNA3 or ocp fp8/bf8 on CDNA4, is
1248-
// done in two steps: fp8/bf8->fp16 and fp16->fp32
1288+
// fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
1289+
// is done in two steps: fp8/bf8->fp16 and fp16->fp32
12491290
bool isDstFP32 = dstElementType.isF32();
12501291
bool useFP16IntermediateDst =
12511292
(isDstFP32 &&
@@ -1277,8 +1318,8 @@ struct FpToFpOpConversion
12771318
}
12781319
if (useFP16IntermediateSrc)
12791320
for (Value &v : inVals)
1280-
v = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, v,
1281-
roundingMode.value_or(RoundingMode::RTNE));
1321+
v = LLVM::AMD::cvtFp32ToFp16RTNE(loc, rewriter, v);
1322+
12821323
inVals.resize(numElements, b.undef(typeConverter->convertType(srcType)));
12831324
SmallVector<Value> outVals;
12841325
if (srcType != dstType) {

third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,7 @@ Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, Value scale,
198198
auto b = TritonLLVMOpBuilder(loc, rewriter);
199199
Value scaleF32 =
200200
b.bitcast(b.shl(b.zext(i32_ty, scale), b.i32_val(23)), f32_ty);
201-
Value scaleF16 =
202-
LLVM::AMD::cvtFp32ToFp16(loc, rewriter, scaleF32, RoundingMode::RTNE);
201+
Value scaleF16 = LLVM::AMD::cvtFp32ToFp16RTNE(loc, rewriter, scaleF32);
203202
Value mulF16 = b.fmul(v, scaleF16);
204203
if (fastMath)
205204
return mulF16;

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -496,28 +496,10 @@ int32_t getCtrlBitsForCacheModifierOnTarget(
496496
}
497497
}
498498

499-
Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v,
500-
triton::RoundingMode rounding) {
501-
if (rounding == triton::RoundingMode::RTNE) {
502-
LLVM::RoundingMode rm = LLVM::RoundingMode::NearestTiesToEven;
503-
return rewriter.create<LLVM::ConstrainedFPTruncIntr>(
504-
loc, f16_ty, v, rm, LLVM::FPExceptionBehavior::Ignore);
505-
}
506-
507-
// TODO: Figure out the test failure with RTZ LLVM::ConstrainedFPTruncIntr and
508-
// switch to not use inline assembly too.
509-
assert(rounding == triton::RoundingMode::RTZ);
510-
GCNBuilder builder;
511-
512-
auto &cvt = *builder.create("v_cvt_f16_f32");
513-
auto res = builder.newOperand("=v");
514-
auto operand = builder.newOperand(v, "v");
515-
auto &setRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0xc");
516-
setRTZ();
517-
cvt(res, operand);
518-
auto &resetRTZ = *builder.create("s_setreg_imm32_b32 0x1801, 0x0");
519-
resetRTZ();
520-
return builder.launch(rewriter, loc, f16_ty, false);
499+
Value cvtFp32ToFp16RTNE(Location loc, RewriterBase &rewriter, const Value &v) {
500+
LLVM::RoundingMode rm = LLVM::RoundingMode::NearestTiesToEven;
501+
return rewriter.create<LLVM::ConstrainedFPTruncIntr>(
502+
loc, f16_ty, v, rm, LLVM::FPExceptionBehavior::Ignore);
521503
}
522504

523505
Type getPointerTypeWithShape(Value basePtr, Value offset) {

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ getCtrlBitsForCacheModifierOnTarget(triton::CacheModifier, bool,
6262
int32_t getCtrlBitsForBufferAtomicsOnGFX_942_950(bool setSC0, bool setSC1,
6363
bool setNT);
6464

65-
Value cvtFp32ToFp16(Location loc, RewriterBase &rewriter, const Value &v,
66-
triton::RoundingMode rounding);
65+
Value cvtFp32ToFp16RTNE(Location loc, RewriterBase &rewriter, const Value &v);
6766

6867
// Return a tensor of pointers with the same type of `basePtr` and the same
6968
// shape of `offset`

0 commit comments

Comments
 (0)