Skip to content

Commit c5fed8e

Browse files
authored
[AMD] Fix packed f32 to fp16 cast for single value (triton-lang#6545)
We need to check whether we only have a single value before assuming a size-2 vector and using the packed version.
1 parent a8e5788 commit c5fed8e

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

test/Conversion/amd/fp_to_fp.mlir

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck --check-prefix=GFX942 %s
2-
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck --check-prefix=GFX950 %s
1+
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck --check-prefixes=COMMON,GFX942 %s
2+
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck --check-prefixes=COMMON,GFX950 %s
33

44
// CHECK-LABEL: f16_to_f32
55
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -32,15 +32,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
3232
// GFX942-COUNT-8: llvm.fptrunc %{{.+}} : f32 to f16
3333
// GFX950-COUNT-4: llvm.fptrunc %{{.+}} : vector<2xf32> to vector<2xf16>
3434
%0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
35-
// GFX942-COUNT-4: rocdl.cvt.pkrtz
36-
// GFX950-COUNT-4: rocdl.cvt.pkrtz
35+
// COMMON-COUNT-4: rocdl.cvt.pkrtz
3736
%1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
3837
tt.return
3938
}
4039
}
4140

4241
// -----
4342

43+
// CHECK-LABEL: f32_to_f16_single_value
44+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
45+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
46+
tt.func @f32_to_f16_single_value(%arg0: tensor<1x128xf32, #blocked>) {
47+
// COMMON: llvm.fptrunc %{{.+}} : f32 to f16
48+
// COMMON-NOT: llvm.fptrunc
49+
%0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<1x128xf32, #blocked> -> tensor<1x128xf16, #blocked>
50+
// COMMON: rocdl.cvt.pkrtz
51+
// COMMON-NOT: rocdl.cvt.pkrtz
52+
%1 = tt.fp_to_fp %arg0, rounding = rtz : tensor<1x128xf32, #blocked> -> tensor<1x128xf16, #blocked>
53+
tt.return
54+
}
55+
}
56+
57+
// -----
58+
4459
// CHECK-LABEL: downcast_to_f8
4560
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
4661
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,17 +596,20 @@ convertFp32ToFp16RTZ(Location loc, ConversionPatternRewriter &rewriter,
596596
// Fp32->Fp16/Bf16 (RTNE) in GFX950
597597
static SmallVector<Value>
598598
convertFp32ToFp16RTNE(Location loc, ConversionPatternRewriter &rewriter,
599-
const SmallVector<Value> &v, Type outElemTy) {
600-
assert(v.size() == 2);
599+
ArrayRef<Value> v, Type outElemTy) {
601600
auto b = TritonLLVMOpBuilder(loc, rewriter);
601+
if (v.size() == 1)
602+
return {b.fptrunc(outElemTy, v.front())};
603+
604+
assert(v.size() == 2);
602605
auto inVecTy = vec_ty(f32_ty, 2);
603606
auto retVecTy = vec_ty(outElemTy, 2);
604607
Value inVec = b.undef(inVecTy);
605608
auto idx0 = b.i32_val(0);
606609
auto idx1 = b.i32_val(1);
607610
inVec = b.insert_element(inVecTy, inVec, v[0], idx0);
608611
inVec = b.insert_element(inVecTy, inVec, v[1], idx1);
609-
Value retVec = rewriter.create<LLVM::FPTruncOp>(loc, retVecTy, inVec);
612+
Value retVec = b.fptrunc(retVecTy, inVec);
610613
SmallVector<Value> ret(2);
611614
ret[0] = b.extract_element(outElemTy, retVec, idx0);
612615
ret[1] = b.extract_element(outElemTy, retVec, idx1);
@@ -680,6 +683,7 @@ static SmallVector<Value> Fp32_to_F16_RTNE(Location loc,
680683
Type inElemTy, Type outElemTy,
681684
MultipleOperandsRange operands,
682685
AMD::ISAFamily isaFamily) {
686+
// For CDNA4 we can potentially use packed v_cvt_pk_[b]f16_f32 instructions.
683687
if (isaFamily == AMD::ISAFamily::CDNA4) {
684688
SmallVector<Value> inVals;
685689
size_t numElem = std::min(size_t(2), operands.size());

0 commit comments

Comments
 (0)