Skip to content

Commit e6aa86c

Browse files
authored
[BACKEND] Fix fp16 to fp32 conversion (#7585)
Fixes #6698
1 parent ca0fe1b commit e6aa86c

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2583,3 +2583,16 @@ tt.func private @arith_constant_array() {
25832583
tt.return
25842584
}
25852585
}
2586+
2587+
// -----
2588+
2589+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
2590+
2591+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:75", "ttg.threads-per-warp" = 32 : i32} {
2592+
// CHECK-LABEL: fp16_to_fp32
2593+
tt.func public @fp16_to_fp32(%arg0 : tensor<256xf16, #blocked>) {
2594+
// CHECK: llvm.fpext %{{.*}} : f16 to f32
2595+
%0 = tt.fp_to_fp %arg0 : tensor<256xf16, #blocked> -> tensor<256xf32, #blocked>
2596+
tt.return
2597+
}
2598+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,12 @@ struct FpToFpOpConversion
489489
}
490490
}
491491

492+
if (srcElementType.isF16() && dstElementType.isF32()) {
493+
return llvm::to_vector(llvm::map_range(operands[0], [&](Value v) {
494+
return convertFp16ToFp32(loc, rewriter, v);
495+
}));
496+
}
497+
492498
if (srcElementType.isF32() && dstElementType.isF16()) {
493499
assert(roundingMode.has_value() &&
494500
"rounding mode must be specified for fp32->fp16 conversion");

0 commit comments

Comments
 (0)