Skip to content

Commit 2e8acb1

Browse files
committed
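Change the signatures of the f32-to-fp{4,6,8}x4 RS conversion builtins and intrinsics to take a single <4 x float> vector instead of four scalar floats (renaming them from ff_to_* to f32x4_to_*)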
1 parent 398398b commit 2e8acb1

File tree

8 files changed

+223
-153
lines changed

8 files changed

+223
-153
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -624,10 +624,10 @@ def __nvvm_e4m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
624624
def __nvvm_e5m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
625625
def __nvvm_e5m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM_89, PTX81>;
626626

627-
def __nvvm_ff_to_e4m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
628-
def __nvvm_ff_to_e4m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
629-
def __nvvm_ff_to_e5m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
630-
def __nvvm_ff_to_e5m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
627+
def __nvvm_f32x4_to_e4m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
628+
def __nvvm_f32x4_to_e4m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
629+
def __nvvm_f32x4_to_e5m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
630+
def __nvvm_f32x4_to_e5m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
631631

632632
def __nvvm_ff_to_e2m3x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
633633
def __nvvm_ff_to_e2m3x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
@@ -639,19 +639,19 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
639639
def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
640640
def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
641641

642-
def __nvvm_ff_to_e2m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
643-
def __nvvm_ff_to_e2m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
644-
def __nvvm_ff_to_e3m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
645-
def __nvvm_ff_to_e3m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
642+
def __nvvm_f32x4_to_e2m3x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
643+
def __nvvm_f32x4_to_e2m3x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
644+
def __nvvm_f32x4_to_e3m2x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
645+
def __nvvm_f32x4_to_e3m2x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"_Vector<4, char>(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
646646

647647
def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
648648
def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
649649

650650
def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
651651
def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
652652

653-
def __nvvm_ff_to_e2m1x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
654-
def __nvvm_ff_to_e2m1x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float, float, float, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
653+
def __nvvm_f32x4_to_e2m1x4_rs_satfinite : NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
654+
def __nvvm_f32x4_to_e2m1x4_rs_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(_Vector<4, float>, uint32_t)", SM<"100a", [SM_103a]>, PTX87>;
655655

656656
def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
657657
def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,45 +1244,45 @@ __device__ void nvvm_cvt_sm100a_sm103a() {
12441244
// CHECK_PTX87_SM103a: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, i32 0)
12451245
__nvvm_ff2bf16x2_rs_relu_satfinite(1.0f, 1.0f, 0);
12461246

1247-
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e4m3x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1248-
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e4m3x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1249-
__nvvm_ff_to_e4m3x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1247+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1248+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1249+
__nvvm_f32x4_to_e4m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12501250

1251-
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e4m3x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1252-
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e4m3x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1253-
__nvvm_ff_to_e4m3x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1251+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1252+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1253+
__nvvm_f32x4_to_e4m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12541254

1255-
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e5m2x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1256-
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e5m2x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1257-
__nvvm_ff_to_e5m2x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1255+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1256+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1257+
__nvvm_f32x4_to_e5m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12581258

1259-
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e5m2x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1260-
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e5m2x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1261-
__nvvm_ff_to_e5m2x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1259+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1260+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1261+
__nvvm_f32x4_to_e5m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12621262

1263-
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e2m3x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1264-
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e2m3x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1265-
__nvvm_ff_to_e2m3x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1263+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1264+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1265+
__nvvm_f32x4_to_e2m3x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12661266

1267-
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e2m3x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1268-
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e2m3x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1269-
__nvvm_ff_to_e2m3x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1267+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1268+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1269+
__nvvm_f32x4_to_e2m3x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12701270

1271-
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e3m2x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1272-
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e3m2x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1273-
__nvvm_ff_to_e3m2x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1271+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1272+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1273+
__nvvm_f32x4_to_e3m2x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12741274

1275-
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.ff.to.e3m2x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1276-
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.ff.to.e3m2x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1277-
__nvvm_ff_to_e3m2x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1275+
// CHECK_PTX87_SM100a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1276+
// CHECK_PTX87_SM103a: call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1277+
__nvvm_f32x4_to_e3m2x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12781278

1279-
// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1280-
// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.ff.to.e2m1x4.rs.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1281-
__nvvm_ff_to_e2m1x4_rs_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1279+
// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1280+
// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1281+
__nvvm_f32x4_to_e2m1x4_rs_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12821282

1283-
// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1284-
// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.ff.to.e2m1x4.rs.relu.satfinite(float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, i32 0)
1285-
__nvvm_ff_to_e2m1x4_rs_relu_satfinite(1.0f, 1.0f, 1.0f, 1.0f, 0);
1283+
// CHECK_PTX87_SM100a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1284+
// CHECK_PTX87_SM103a: call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> splat (float 1.000000e+00), i32 0)
1285+
__nvvm_f32x4_to_e2m1x4_rs_relu_satfinite({1.0f, 1.0f, 1.0f, 1.0f}, 0);
12861286
#endif
12871287
}
12881288

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,8 +1530,8 @@ let TargetPrefix = "nvvm" in {
15301530
// RS rounding mode conversions for f8x4 types
15311531
foreach type = ["e4m3x4", "e5m2x4"] in {
15321532
foreach relu = ["", "_relu"] in {
1533-
def int_nvvm_ff_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
1534-
PureIntrinsic<[llvm_v4i8_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1533+
def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
1534+
PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
15351535
}
15361536
}
15371537

@@ -1546,8 +1546,8 @@ let TargetPrefix = "nvvm" in {
15461546

15471547
// RS rounding mode conversions for f4x4 type
15481548
foreach relu = ["", "_relu"] in {
1549-
def int_nvvm_ff_to_e2m1x4_rs # relu # _satfinite : NVVMBuiltin,
1550-
PureIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1549+
def int_nvvm_f32x4_to_e2m1x4_rs # relu # _satfinite : NVVMBuiltin,
1550+
PureIntrinsic<[llvm_i16_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
15511551
}
15521552

15531553
// FP6 conversions.
@@ -1564,8 +1564,8 @@ let TargetPrefix = "nvvm" in {
15641564
// RS rounding mode conversions for f6x4 types
15651565
foreach type = ["e2m3x4", "e3m2x4"] in {
15661566
foreach relu = ["", "_relu"] in {
1567-
def int_nvvm_ff_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
1568-
PureIntrinsic<[llvm_v4i8_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_float_ty, llvm_i32_ty]>;
1567+
def int_nvvm_f32x4_to_ # type # _rs # relu # _satfinite : NVVMBuiltin,
1568+
PureIntrinsic<[llvm_v4i8_ty], [llvm_v4f32_ty, llvm_i32_ty]>;
15691569
}
15701570
}
15711571

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,9 +1077,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10771077
// Enable custom lowering for the following:
10781078
// * MVT::i128 - clusterlaunchcontrol
10791079
// * MVT::i32 - prmt
1080+
// * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
10801081
// * MVT::Other - internal.addrspace.wrap
1081-
setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
1082-
Custom);
1082+
setOperationAction(ISD::INTRINSIC_WO_CHAIN,
1083+
{MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
10831084
}
10841085

10851086
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1162,6 +1163,11 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
11621163
NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
11631164
MAKE_CASE(
11641165
NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
1166+
MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF)
1167+
MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF)
1168+
MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF)
1169+
MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF)
1170+
MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF)
11651171
}
11661172
return nullptr;
11671173

@@ -2839,6 +2845,69 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
28392845
{TryCancelResponse0, TryCancelResponse1});
28402846
}
28412847

2848+
/// Returns true if \p ID is one of the `relu` variants of the
/// nvvm.f32x4.to.{e4m3x4,e5m2x4,e2m3x4,e3m2x4,e2m1x4}.rs.*.satfinite
/// conversion intrinsics, and false for every other intrinsic ID
/// (including the non-relu variants of the same conversions).
///
/// File-local helper: declared `static` so it has internal linkage, matching
/// the other lowering helpers in this file.
static bool isCvtRSReluIntrinsic(Intrinsic::ID ID) {
  switch (ID) {
  case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
  case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
    return true;
  default:
    return false;
  }
}
2860+
2861+
static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
2862+
SDNode *N = Op.getNode();
2863+
SDLoc DL(N);
2864+
SDValue F32Vec = N->getOperand(1);
2865+
SDValue RBits = N->getOperand(2);
2866+
2867+
unsigned IntrinsicID = N->getConstantOperandVal(0);
2868+
2869+
uint32_t CvtModeFlag = NVPTX::PTXCvtMode::CvtMode::RS;
2870+
if (isCvtRSReluIntrinsic(IntrinsicID))
2871+
CvtModeFlag |= NVPTX::PTXCvtMode::CvtMode::RELU_FLAG;
2872+
2873+
SDValue Float1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2874+
DAG.getIntPtrConstant(0, DL));
2875+
SDValue Float2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2876+
DAG.getIntPtrConstant(1, DL));
2877+
SDValue Float3 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2878+
DAG.getIntPtrConstant(2, DL));
2879+
SDValue Float4 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
2880+
DAG.getIntPtrConstant(3, DL));
2881+
2882+
auto OpSignature =
2883+
[&]() -> std::pair<NVPTXISD::NodeType, MVT::SimpleValueType> {
2884+
switch (IntrinsicID) {
2885+
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2886+
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
2887+
return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8};
2888+
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2889+
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
2890+
return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8};
2891+
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2892+
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
2893+
return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8};
2894+
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
2895+
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
2896+
return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8};
2897+
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
2898+
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
2899+
return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16};
2900+
default:
2901+
llvm_unreachable("unsupported/unhandled intrinsic");
2902+
}
2903+
}();
2904+
2905+
SDValue Ops[] = {Float1, Float2, Float3,
2906+
Float4, RBits, DAG.getConstant(CvtModeFlag, DL, MVT::i32)};
2907+
2908+
return DAG.getNode(OpSignature.first, DL, OpSignature.second, Ops);
2909+
}
2910+
28422911
static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
28432912
const unsigned Mode = [&]() {
28442913
switch (Op->getConstantOperandVal(0)) {
@@ -2886,6 +2955,17 @@ static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
28862955
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
28872956
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
28882957
return LowerClusterLaunchControlQueryCancel(Op, DAG);
2958+
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
2959+
case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
2960+
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
2961+
case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
2962+
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
2963+
case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
2964+
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
2965+
case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
2966+
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
2967+
case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
2968+
return lowerCvtRSIntrinsics(Op, DAG);
28892969
}
28902970
}
28912971

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ enum NodeType : unsigned {
7979
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X,
8080
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y,
8181
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,
82+
CVT_E4M3X4_F32X4_RS_SF,
83+
CVT_E5M2X4_F32X4_RS_SF,
84+
CVT_E2M3X4_F32X4_RS_SF,
85+
CVT_E3M2X4_F32X4_RS_SF,
86+
CVT_E2M1X4_F32X4_RS_SF,
8287

8388
FIRST_MEMORY_OPCODE,
8489

0 commit comments

Comments
 (0)