[NVPTX] f32 -> tf32 conversion intrinsics and instructions

Summary of what this patch contains (text before the first "diff --git"
line is ignored by patch tools):

* IntrinsicsNVVM.td: declares five intrinsics next to the existing
  int_nvvm_f2tf32_rna: int_nvvm_f2tf32_rna_satfinite, int_nvvm_f2tf32_rn,
  int_nvvm_f2tf32_rn_relu, int_nvvm_f2tf32_rz and int_nvvm_f2tf32_rz_relu.
  Each takes one float, returns i32, and is IntrNoMem + IntrNoCallback.

* NVPTXInstrInfo.td: adds the CVT_TO_TF32 multiclass. Its Modifier string is
  spliced into the assembly text as "cvt.<Modifier>.tf32.f32", and the
  intrinsic to select from is found by name: "int_nvvm_f2tf32_" followed by
  Modifier with every "." replaced by "_". The Preds list is attached through
  Requires and defaults to [hasPTX<78>, hasSM<90>]; the rna instantiation
  overrides it with [hasPTX<70>, hasSM<80>] and the rna.satfinite one with
  [hasPTX<81>, hasSM<89>].

* NVPTXIntrinsics.td: removes the stand-alone CVT_tf32_f32 definition, which
  emitted "cvt.rna.tf32.f32" and carried no Requires clause in the removed
  text. The rna form is now produced by CVT_to_tf32_rna.

* Tests: convert-sm89.ll gains a check for cvt.rna.satfinite.tf32.f32, and
  the new convert-sm90.ll checks the rn, rn.relu, rz and rz.relu forms.

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index ae04a130bc825..00a76018d8415 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1438,6 +1438,16 @@ let TargetPrefix = "nvvm" in {
 
   def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
       Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rna_satfinite : ClangBuiltin<"__nvvm_f2tf32_rna_satfinite">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rn : ClangBuiltin<"__nvvm_f2tf32_rn">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rn_relu : ClangBuiltin<"__nvvm_f2tf32_rn_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rz : ClangBuiltin<"__nvvm_f2tf32_rz">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rz_relu : ClangBuiltin<"__nvvm_f2tf32_rz_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
 
   def int_nvvm_ff_to_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn">,
       Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index c3e72d6ce3a3f..6a95d9ebef6c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -725,6 +725,23 @@ let hasSideEffects = false in {
 
   def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">;
   def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
+
+  // Float to TF32 conversions
+  multiclass CVT_TO_TF32<string Modifier, list<Predicate> Preds = [hasPTX<78>, hasSM<90>]> {
+    defvar Intr = !cast<Intrinsic>("int_nvvm_f2tf32_" # !subst(".", "_", Modifier));
+
+    def NAME : NVPTXInst<(outs Int32Regs:$dst), (ins Float32Regs:$src),
+               "cvt." # Modifier # ".tf32.f32 \t$dst, $src;",
+               [(set i32:$dst, (Intr f32:$src))]>,
+               Requires<Preds>;
+  }
+
+  defm CVT_to_tf32_rn : CVT_TO_TF32<"rn">;
+  defm CVT_to_tf32_rz : CVT_TO_TF32<"rz">;
+  defm CVT_to_tf32_rn_relu : CVT_TO_TF32<"rn.relu">;
+  defm CVT_to_tf32_rz_relu : CVT_TO_TF32<"rz.relu">;
+  defm CVT_to_tf32_rna : CVT_TO_TF32<"rna", [hasPTX<70>, hasSM<80>]>;
+  defm CVT_to_tf32_rna_satf : CVT_TO_TF32<"rna.satfinite", [hasPTX<81>, hasSM<89>]>;
 }
 
 def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 22339ebc5484f..4f144cc641080 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1722,11 +1722,6 @@ def : Pat<(int_nvvm_f2bf16_rz f32:$a),
 def : Pat<(int_nvvm_f2bf16_rz_relu f32:$a),
           (CVT_bf16_f32 $a, CvtRZ_RELU)>;
 
-def CVT_tf32_f32 :
-   NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
-                   "cvt.rna.tf32.f32 \t$dest, $a;",
-       [(set i32:$dest, (int_nvvm_f2tf32_rna f32:$a))]>;
-
 def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};",
   Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>;
 
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
index 5d0576aebbe08..30fd76f5a31c2 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm89.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
@@ -84,3 +84,10 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
   %val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
   ret <2 x half> %val
 }
+
+; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
+define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
+; CHECK: cvt.rna.satfinite.tf32.f32
+  %val = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %f1)
+  ret i32 %val
+}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm90.ll b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
new file mode 100644
index 0000000000000..5f610e0e91f88
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| %ptxas-verify -arch=sm_90 %}
+
+declare i32 @llvm.nvvm.f2tf32.rn(float %f1)
+declare i32 @llvm.nvvm.f2tf32.rn.relu(float %f1)
+declare i32 @llvm.nvvm.f2tf32.rz(float %f1)
+declare i32 @llvm.nvvm.f2tf32.rz.relu(float %f1)
+
+define i32 @cvt_rn_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rn.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.f2tf32.rn(float %f1)
+  ret i32 %val
+}
+
+define i32 @cvt_rn_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_relu_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_relu_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rn.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.f2tf32.rn.relu(float %f1)
+  ret i32 %val
+}
+
+define i32 @cvt_rz_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rz_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rz.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.f2tf32.rz(float %f1)
+  ret i32 %val
+}
+
+define i32 @cvt_rz_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_relu_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rz_relu_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rz.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.f2tf32.rz.relu(float %f1)
+  ret i32 %val
+}