Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions clang/include/clang/Basic/BuiltinsNVPTX.td
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,21 @@ def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf1
def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
// f32 fused multiply-add builtins. The name is built as
// __nvvm_fma_<rnd>[_ftz][_sat]_f with <rnd> one of rn, rz, rm, rp; every
// variant takes three floats and returns a float.
def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
Expand Down Expand Up @@ -447,19 +455,51 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">;
// Add

// f32 variants: __nvvm_add_<rnd>[_ftz][_sat]_f with <rnd> one of rn, rz, rm,
// rp. Each takes two floats and returns a float.
def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_sat_f : NVPTXBuiltin<"float(float, float)">;

// f64 variants: one per rounding mode; no ftz or sat forms are declared here.
def __nvvm_add_rn_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;

// Sub

// f32 variants: __nvvm_sub_<rnd>[_ftz][_sat]_f with <rnd> one of rn, rz, rm,
// rp. Each takes two floats and returns a float. The set mirrors the add
// builtins name for name.
def __nvvm_sub_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rn_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rn_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rm_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rm_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rm_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rp_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rp_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rp_sat_f : NVPTXBuiltin<"float(float, float)">;

// f64 variants: one per rounding mode; no ftz or sat forms are declared here.
def __nvvm_sub_rn_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_sub_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_sub_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_sub_rp_d : NVPTXBuiltin<"double(double, double)">;

// Convert

def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
Expand Down
92 changes: 92 additions & 0 deletions clang/test/CodeGen/builtins-nvptx.c
Original file line number Diff line number Diff line change
Expand Up @@ -1466,3 +1466,95 @@ __device__ void nvvm_min_max_sm86() {
#endif
// CHECK: ret void
}

// Calls every saturating (_sat) f32 add, sub and fma builtin, for each
// rounding mode and both with and without ftz, and expects each call to be
// emitted as a call to the identically named llvm.nvvm.* intrinsic.
// CHECK-LABEL: nvvm_add_sub_fma_f32_sat
__device__ void nvvm_add_sub_fma_f32_sat() {
// Saturating add variants.
// CHECK: call float @llvm.nvvm.add.rn.sat.f
__nvvm_add_rn_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rn.ftz.sat.f
__nvvm_add_rn_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rz.sat.f
__nvvm_add_rz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rz.ftz.sat.f
__nvvm_add_rz_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rm.sat.f
__nvvm_add_rm_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rm.ftz.sat.f
__nvvm_add_rm_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rp.sat.f
__nvvm_add_rp_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rp.ftz.sat.f
__nvvm_add_rp_ftz_sat_f(1.0f, 2.0f);

// Saturating sub variants.
// CHECK: call float @llvm.nvvm.sub.rn.sat.f
__nvvm_sub_rn_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rn.ftz.sat.f
__nvvm_sub_rn_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rz.sat.f
__nvvm_sub_rz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rz.ftz.sat.f
__nvvm_sub_rz_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rm.sat.f
__nvvm_sub_rm_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rm.ftz.sat.f
__nvvm_sub_rm_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rp.sat.f
__nvvm_sub_rp_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rp.ftz.sat.f
__nvvm_sub_rp_ftz_sat_f(1.0f, 2.0f);

// Saturating fma variants (three operands).
// CHECK: call float @llvm.nvvm.fma.rn.sat.f
__nvvm_fma_rn_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rn.ftz.sat.f
__nvvm_fma_rn_ftz_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rz.sat.f
__nvvm_fma_rz_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rz.ftz.sat.f
__nvvm_fma_rz_ftz_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rm.sat.f
__nvvm_fma_rm_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rm.ftz.sat.f
__nvvm_fma_rm_ftz_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rp.sat.f
__nvvm_fma_rp_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rp.ftz.sat.f
__nvvm_fma_rp_ftz_sat_f(1.0f, 2.0f, 3.0f);

// CHECK: ret void
}

// Calls the non-saturating f32 sub builtins for each rounding mode, with and
// without ftz, and expects each to be emitted as a call to the identically
// named llvm.nvvm.sub.* intrinsic.
// CHECK-LABEL: nvvm_sub_f32
__device__ void nvvm_sub_f32() {
// CHECK: call float @llvm.nvvm.sub.rn.f
__nvvm_sub_rn_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rn.ftz.f
__nvvm_sub_rn_ftz_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rz.f
__nvvm_sub_rz_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rz.ftz.f
__nvvm_sub_rz_ftz_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rm.f
__nvvm_sub_rm_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rm.ftz.f
__nvvm_sub_rm_ftz_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rp.f
__nvvm_sub_rp_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rp.ftz.f
__nvvm_sub_rp_ftz_f(1.0f, 2.0f);

// CHECK: ret void
}

// Calls the f64 sub builtins for each rounding mode and expects each to be
// emitted as a call to the identically named llvm.nvvm.sub.*.d intrinsic.
// The arguments are double literals so they match the builtins'
// double(double, double) signature without an implicit float-to-double
// conversion at the call site.
// CHECK-LABEL: nvvm_sub_f64
__device__ void nvvm_sub_f64() {
// CHECK: call double @llvm.nvvm.sub.rn.d
__nvvm_sub_rn_d(1.0, 2.0);
// CHECK: call double @llvm.nvvm.sub.rz.d
__nvvm_sub_rz_d(1.0, 2.0);
// CHECK: call double @llvm.nvvm.sub.rm.d
__nvvm_sub_rm_d(1.0, 2.0);
// CHECK: call double @llvm.nvvm.sub.rp.d
__nvvm_sub_rp_d(1.0, 2.0);

// CHECK: ret void
}
40 changes: 30 additions & 10 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -1376,13 +1376,15 @@ let TargetPrefix = "nvvm" in {
} // ftz
} // variant

foreach rnd = ["rn", "rz", "rm", "rp"] in {
foreach ftz = ["", "_ftz"] in
def int_nvvm_fma_ # rnd # ftz # _f : NVVMBuiltin,
PureIntrinsic<[llvm_float_ty],
[llvm_float_ty, llvm_float_ty, llvm_float_ty]>;
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
foreach sat = ["", "_sat"] in
def int_nvvm_fma # rnd # ftz # sat # _f : NVVMBuiltin,
PureIntrinsic<[llvm_float_ty],
[llvm_float_ty, llvm_float_ty, llvm_float_ty]>;
}

def int_nvvm_fma_ # rnd # _d : NVVMBuiltin,
def int_nvvm_fma # rnd # _d : NVVMBuiltin,
PureIntrinsic<[llvm_double_ty],
[llvm_double_ty, llvm_double_ty, llvm_double_ty]>;
}
Expand Down Expand Up @@ -1443,12 +1445,30 @@ let TargetPrefix = "nvvm" in {
// Add
//
let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
foreach rnd = ["rn", "rz", "rm", "rp"] in {
foreach ftz = ["", "_ftz"] in
def int_nvvm_add_ # rnd # ftz # _f : NVVMBuiltin,
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
foreach sat = ["", "_sat"] in
def int_nvvm_add # rnd # ftz # sat # _f : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
}

def int_nvvm_add # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
}

//
// Sub
//
let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
foreach sat = ["", "_sat"] in
def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
}

def int_nvvm_add_ # rnd # _d : NVVMBuiltin,
def int_nvvm_sub # rnd # _d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
}
}
Expand Down
86 changes: 86 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1637,13 +1637,21 @@ multiclass FMA_INST {
FMA_TUPLE<"_rp_f64", int_nvvm_fma_rp_d, B64>,

FMA_TUPLE<"_rn_ftz_f32", int_nvvm_fma_rn_ftz_f, B32>,
FMA_TUPLE<"_rn_ftz_sat_f32", int_nvvm_fma_rn_ftz_sat_f, B32>,
FMA_TUPLE<"_rn_f32", int_nvvm_fma_rn_f, B32>,
FMA_TUPLE<"_rn_sat_f32", int_nvvm_fma_rn_sat_f, B32>,
FMA_TUPLE<"_rz_ftz_f32", int_nvvm_fma_rz_ftz_f, B32>,
FMA_TUPLE<"_rz_ftz_sat_f32", int_nvvm_fma_rz_ftz_sat_f, B32>,
FMA_TUPLE<"_rz_f32", int_nvvm_fma_rz_f, B32>,
FMA_TUPLE<"_rz_sat_f32", int_nvvm_fma_rz_sat_f, B32>,
FMA_TUPLE<"_rm_f32", int_nvvm_fma_rm_f, B32>,
FMA_TUPLE<"_rm_sat_f32", int_nvvm_fma_rm_sat_f, B32>,
FMA_TUPLE<"_rm_ftz_f32", int_nvvm_fma_rm_ftz_f, B32>,
FMA_TUPLE<"_rm_ftz_sat_f32", int_nvvm_fma_rm_ftz_sat_f, B32>,
FMA_TUPLE<"_rp_f32", int_nvvm_fma_rp_f, B32>,
FMA_TUPLE<"_rp_sat_f32", int_nvvm_fma_rp_sat_f, B32>,
FMA_TUPLE<"_rp_ftz_f32", int_nvvm_fma_rp_ftz_f, B32>,
FMA_TUPLE<"_rp_ftz_sat_f32", int_nvvm_fma_rp_ftz_sat_f, B32>,

FMA_TUPLE<"_rn_f16", int_nvvm_fma_rn_f16, B16, [hasPTX<42>, hasSM<53>]>,
FMA_TUPLE<"_rn_ftz_f16", int_nvvm_fma_rn_ftz_f16, B16,
Expand Down Expand Up @@ -1694,6 +1702,22 @@ multiclass FMA_INST {

defm INT_NVVM_FMA : FMA_INST;

// Mixed-precision FMA: defines one instruction per combination of rounding
// mode, optional saturation and 16-bit source type (f16 or bf16).
// $a and $b are 16-bit registers; $c and $dst are 32-bit registers.
foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
foreach sat = ["", "_SAT"] in {
foreach type = ["F16", "BF16"] in {
def INT_NVVM_FMA # rnd # sat # _F32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
// Asm string: replace '_' with '.' and lowercase, e.g. "fma.rn.sat.f32.f16".
!tolower(!subst("_", ".", "fma" # rnd # sat # "_f32_" # type)),
// Selected when the non-ftz f32 fma intrinsic for this rounding/sat
// combination has both multiplicands produced by an fpextend from the
// 16-bit type, so the extension is folded into the instruction.
[(set f32:$dst,
(!cast<Intrinsic>(!tolower("int_nvvm_fma" # rnd # sat # "_f"))
(f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
(f32 (fpextend !cast<ValueType>(!tolower(type)):$b)),
f32:$c))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
}

//
// Rcp
//
Expand Down Expand Up @@ -1793,19 +1817,81 @@ let Predicates = [doRsqrtOpt] in {
//

// f32 add intrinsics, one instruction per rounding mode with optional ftz and
// sat. The asm strings follow the PTX grammar add{.rnd}{.ftz}{.sat}.f32, so
// ".ftz" is emitted before ".sat" (the same order the FMA tuples use).
// Record names are kept as-is so existing references keep resolving.
def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
def INT_NVVM_ADD_RN_SAT_FTZ_F : F_MATH_2<"add.rn.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>;
def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>;
def INT_NVVM_ADD_RN_SAT_F : F_MATH_2<"add.rn.sat.f32", B32, B32, B32, int_nvvm_add_rn_sat_f>;
def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>;
def INT_NVVM_ADD_RZ_SAT_FTZ_F : F_MATH_2<"add.rz.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rz_ftz_sat_f>;
def INT_NVVM_ADD_RZ_F : F_MATH_2<"add.rz.f32", B32, B32, B32, int_nvvm_add_rz_f>;
def INT_NVVM_ADD_RZ_SAT_F : F_MATH_2<"add.rz.sat.f32", B32, B32, B32, int_nvvm_add_rz_sat_f>;
def INT_NVVM_ADD_RM_FTZ_F : F_MATH_2<"add.rm.ftz.f32", B32, B32, B32, int_nvvm_add_rm_ftz_f>;
def INT_NVVM_ADD_RM_SAT_FTZ_F : F_MATH_2<"add.rm.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rm_ftz_sat_f>;
def INT_NVVM_ADD_RM_F : F_MATH_2<"add.rm.f32", B32, B32, B32, int_nvvm_add_rm_f>;
def INT_NVVM_ADD_RM_SAT_F : F_MATH_2<"add.rm.sat.f32", B32, B32, B32, int_nvvm_add_rm_sat_f>;
def INT_NVVM_ADD_RP_FTZ_F : F_MATH_2<"add.rp.ftz.f32", B32, B32, B32, int_nvvm_add_rp_ftz_f>;
def INT_NVVM_ADD_RP_SAT_FTZ_F : F_MATH_2<"add.rp.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rp_ftz_sat_f>;
def INT_NVVM_ADD_RP_F : F_MATH_2<"add.rp.f32", B32, B32, B32, int_nvvm_add_rp_f>;
def INT_NVVM_ADD_RP_SAT_F : F_MATH_2<"add.rp.sat.f32", B32, B32, B32, int_nvvm_add_rp_sat_f>;

// f64 add intrinsics: one per rounding mode; no ftz or sat forms are defined.
def INT_NVVM_ADD_RN_D : F_MATH_2<"add.rn.f64", B64, B64, B64, int_nvvm_add_rn_d>;
def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>;
def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>;
def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>;

foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
foreach sat = ["", "_SAT"] in {
foreach type = ["F16", "BF16"] in {
def INT_NVVM_ADD # rnd # sat # _F32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the mixed FP instructions, should we also add folds for the generic fadd/fsub/fma instructions in LLVM?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I've added patterns to fold fadd, fsub, and llvm.fma.f32 to the mixed precision instructions when ftz isn't present. Please take a look, thanks!

!tolower(!subst("_", ".", "add" # rnd # sat # "_f32_" # type)),
[(set f32:$dst,
(!cast<Intrinsic>(!tolower("int_nvvm_add" # rnd # sat # "_f"))
(f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
}

//
// Sub
//

// f32 sub intrinsics, one instruction per rounding mode with optional ftz and
// sat. The asm strings follow the PTX grammar sub{.rnd}{.ftz}{.sat}.f32, so
// ".ftz" is emitted before ".sat" (the same order the FMA tuples use).
// Record names are kept as-is so existing references keep resolving.
def INT_NVVM_SUB_RN_FTZ_F : F_MATH_2<"sub.rn.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_f>;
def INT_NVVM_SUB_RN_SAT_FTZ_F : F_MATH_2<"sub.rn.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f>;
def INT_NVVM_SUB_RN_F : F_MATH_2<"sub.rn.f32", B32, B32, B32, int_nvvm_sub_rn_f>;
def INT_NVVM_SUB_RN_SAT_F : F_MATH_2<"sub.rn.sat.f32", B32, B32, B32, int_nvvm_sub_rn_sat_f>;
def INT_NVVM_SUB_RZ_FTZ_F : F_MATH_2<"sub.rz.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_f>;
def INT_NVVM_SUB_RZ_SAT_FTZ_F : F_MATH_2<"sub.rz.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_sat_f>;
def INT_NVVM_SUB_RZ_F : F_MATH_2<"sub.rz.f32", B32, B32, B32, int_nvvm_sub_rz_f>;
def INT_NVVM_SUB_RZ_SAT_F : F_MATH_2<"sub.rz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_sat_f>;
def INT_NVVM_SUB_RM_FTZ_F : F_MATH_2<"sub.rm.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_f>;
def INT_NVVM_SUB_RM_SAT_FTZ_F : F_MATH_2<"sub.rm.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_sat_f>;
def INT_NVVM_SUB_RM_F : F_MATH_2<"sub.rm.f32", B32, B32, B32, int_nvvm_sub_rm_f>;
def INT_NVVM_SUB_RM_SAT_F : F_MATH_2<"sub.rm.sat.f32", B32, B32, B32, int_nvvm_sub_rm_sat_f>;
def INT_NVVM_SUB_RP_FTZ_F : F_MATH_2<"sub.rp.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_f>;
def INT_NVVM_SUB_RP_SAT_FTZ_F : F_MATH_2<"sub.rp.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_sat_f>;
def INT_NVVM_SUB_RP_F : F_MATH_2<"sub.rp.f32", B32, B32, B32, int_nvvm_sub_rp_f>;
def INT_NVVM_SUB_RP_SAT_F : F_MATH_2<"sub.rp.sat.f32", B32, B32, B32, int_nvvm_sub_rp_sat_f>;

// f64 sub intrinsics: one per rounding mode; no ftz or sat forms are defined.
def INT_NVVM_SUB_RN_D : F_MATH_2<"sub.rn.f64", B64, B64, B64, int_nvvm_sub_rn_d>;
def INT_NVVM_SUB_RZ_D : F_MATH_2<"sub.rz.f64", B64, B64, B64, int_nvvm_sub_rz_d>;
def INT_NVVM_SUB_RM_D : F_MATH_2<"sub.rm.f64", B64, B64, B64, int_nvvm_sub_rm_d>;
def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>;

// Mixed-precision SUB: defines one instruction per combination of rounding
// mode, optional saturation and 16-bit source type (f16 or bf16).
// $a is a 16-bit register; $b and $dst are 32-bit registers.
foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
foreach sat = ["", "_SAT"] in {
foreach type = ["F16", "BF16"] in {
def INT_NVVM_SUB # rnd # sat # _F32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
// Asm string: replace '_' with '.' and lowercase, e.g. "sub.rz.f32.bf16".
!tolower(!subst("_", ".", "sub" # rnd # sat # "_f32_" # type)),
// Selected when the non-ftz f32 sub intrinsic for this rounding/sat
// combination has its first operand produced by an fpextend from the
// 16-bit type; only that operand position is matched.
[(set f32:$dst,
(!cast<Intrinsic>(!tolower("int_nvvm_sub" # rnd # sat # "_f"))
(f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
}
}
}
//
// BFIND
//
Expand Down
Loading