Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions clang/include/clang/Basic/BuiltinsNVPTX.td
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,21 @@ def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf1
def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>;
// f32 fma with an explicit rounding mode (rn/rz/rm/rp) and optional
// flush-to-zero (_ftz) and saturation (_sat) modifiers. Suffix order in the
// builtin name is rnd, then ftz, then sat, mirroring the PTX modifier order.
def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rm_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_ftz_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rp_sat_f : NVPTXBuiltin<"float(float, float, float)">;
def __nvvm_fma_rn_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
Expand Down Expand Up @@ -447,19 +455,51 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">;
// Add

// f32 add: every rounding mode (rn/rz/rm/rp) crossed with optional _ftz and
// _sat modifiers; suffix order is rnd, ftz, sat, matching PTX.
def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rn_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rm_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_add_rp_sat_f : NVPTXBuiltin<"float(float, float)">;

// f64 add: rounding modes only (no ftz/sat variants are declared here).
def __nvvm_add_rn_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;

// Sub

// f32 sub: same rnd x ftz x sat matrix as add above.
def __nvvm_sub_rn_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rn_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rn_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rz_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rm_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rm_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rm_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rp_ftz_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rp_f : NVPTXBuiltin<"float(float, float)">;
def __nvvm_sub_rp_sat_f : NVPTXBuiltin<"float(float, float)">;

// f64 sub: rounding modes only.
def __nvvm_sub_rn_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_sub_rz_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_sub_rm_d : NVPTXBuiltin<"double(double, double)">;
def __nvvm_sub_rp_d : NVPTXBuiltin<"double(double, double)">;

// Convert

def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">;
Expand Down
92 changes: 92 additions & 0 deletions clang/test/CodeGen/builtins-nvptx.c
Original file line number Diff line number Diff line change
Expand Up @@ -1466,3 +1466,95 @@ __device__ void nvvm_min_max_sm86() {
#endif
// CHECK: ret void
}

// Exercises every saturating f32 variant (each rounding mode, with and
// without ftz) of the add/sub/fma builtins and pins the NVVM intrinsic
// each one lowers to.
// CHECK-LABEL: nvvm_add_sub_fma_f32_sat
__device__ void nvvm_add_sub_fma_f32_sat() {
// CHECK: call float @llvm.nvvm.add.rn.sat.f
__nvvm_add_rn_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rn.ftz.sat.f
__nvvm_add_rn_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rz.sat.f
__nvvm_add_rz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rz.ftz.sat.f
__nvvm_add_rz_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rm.sat.f
__nvvm_add_rm_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rm.ftz.sat.f
__nvvm_add_rm_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rp.sat.f
__nvvm_add_rp_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.add.rp.ftz.sat.f
__nvvm_add_rp_ftz_sat_f(1.0f, 2.0f);

// CHECK: call float @llvm.nvvm.sub.rn.sat.f
__nvvm_sub_rn_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rn.ftz.sat.f
__nvvm_sub_rn_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rz.sat.f
__nvvm_sub_rz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rz.ftz.sat.f
__nvvm_sub_rz_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rm.sat.f
__nvvm_sub_rm_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rm.ftz.sat.f
__nvvm_sub_rm_ftz_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rp.sat.f
__nvvm_sub_rp_sat_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rp.ftz.sat.f
__nvvm_sub_rp_ftz_sat_f(1.0f, 2.0f);

// CHECK: call float @llvm.nvvm.fma.rn.sat.f
__nvvm_fma_rn_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rn.ftz.sat.f
__nvvm_fma_rn_ftz_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rz.sat.f
__nvvm_fma_rz_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rz.ftz.sat.f
__nvvm_fma_rz_ftz_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rm.sat.f
__nvvm_fma_rm_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rm.ftz.sat.f
__nvvm_fma_rm_ftz_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rp.sat.f
__nvvm_fma_rp_sat_f(1.0f, 2.0f, 3.0f);
// CHECK: call float @llvm.nvvm.fma.rp.ftz.sat.f
__nvvm_fma_rp_ftz_sat_f(1.0f, 2.0f, 3.0f);

// CHECK: ret void
}

// Exercises the non-saturating f32 sub builtins (all rounding modes, with
// and without ftz) and pins the intrinsic each one lowers to.
// CHECK-LABEL: nvvm_sub_f32
__device__ void nvvm_sub_f32() {
// CHECK: call float @llvm.nvvm.sub.rn.f
__nvvm_sub_rn_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rn.ftz.f
__nvvm_sub_rn_ftz_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rz.f
__nvvm_sub_rz_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rz.ftz.f
__nvvm_sub_rz_ftz_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rm.f
__nvvm_sub_rm_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rm.ftz.f
__nvvm_sub_rm_ftz_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rp.f
__nvvm_sub_rp_f(1.0f, 2.0f);
// CHECK: call float @llvm.nvvm.sub.rp.ftz.f
__nvvm_sub_rp_ftz_f(1.0f, 2.0f);

// CHECK: ret void
}

// Exercises the f64 sub builtins (rounding modes only). The float literals
// are implicitly converted to double at each call per usual C conversions.
// CHECK-LABEL: nvvm_sub_f64
__device__ void nvvm_sub_f64() {
// CHECK: call double @llvm.nvvm.sub.rn.d
__nvvm_sub_rn_d(1.0f, 2.0f);
// CHECK: call double @llvm.nvvm.sub.rz.d
__nvvm_sub_rz_d(1.0f, 2.0f);
// CHECK: call double @llvm.nvvm.sub.rm.d
__nvvm_sub_rm_d(1.0f, 2.0f);
// CHECK: call double @llvm.nvvm.sub.rp.d
__nvvm_sub_rp_d(1.0f, 2.0f);

// CHECK: ret void
}
48 changes: 33 additions & 15 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -1376,16 +1376,18 @@ let TargetPrefix = "nvvm" in {
} // ftz
} // variant

// f32 fma intrinsics: every rounding mode crossed with optional _ftz and
// _sat suffixes (suffix order rnd, ftz, sat). The f64 variant carries only
// the rounding mode. The diff render had interleaved the pre-change lines
// and a review thread into this loop; this is the post-change form only.
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
  foreach ftz = ["", "_ftz"] in {
    foreach sat = ["", "_sat"] in {
      def int_nvvm_fma # rnd # ftz # sat # _f : NVVMBuiltin,
        PureIntrinsic<[llvm_float_ty],
                      [llvm_float_ty, llvm_float_ty, llvm_float_ty]>;
    } // sat
  } // ftz
  def int_nvvm_fma # rnd # _d : NVVMBuiltin,
    PureIntrinsic<[llvm_double_ty],
                  [llvm_double_ty, llvm_double_ty, llvm_double_ty]>;
} // rnd

//
// Rcp
Expand Down Expand Up @@ -1443,15 +1445,31 @@ let TargetPrefix = "nvvm" in {
// Add
//
let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
// f32 add intrinsics: rnd x ftz x sat matrix; f64 has rounding modes only.
// (The diff render had duplicated the pre-change loop lines here; this is
// the post-change form only. Still inside the IntrProperties `let` above.)
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
  foreach ftz = ["", "_ftz"] in {
    foreach sat = ["", "_sat"] in {
      def int_nvvm_add # rnd # ftz # sat # _f : NVVMBuiltin,
        DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
    } // sat
  } // ftz
  def int_nvvm_add # rnd # _d : NVVMBuiltin,
    DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
} // rnd
}

//
// Sub
//
// f32 sub intrinsics: rnd x ftz x sat matrix; f64 has rounding modes only.
// Unlike add, sub is not commutative, so these are not under the
// Commutative IntrProperties `let` used for add above.
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach ftz = ["", "_ftz"] in {
foreach sat = ["", "_sat"] in {
def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
PureIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
} // sat
} // ftz
def int_nvvm_sub # rnd # _d : NVVMBuiltin,
PureIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
} // rnd

//
// Dot Product
Expand Down
114 changes: 114 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1637,13 +1637,21 @@ multiclass FMA_INST {
FMA_TUPLE<"_rp_f64", int_nvvm_fma_rp_d, B64>,

FMA_TUPLE<"_rn_ftz_f32", int_nvvm_fma_rn_ftz_f, B32>,
FMA_TUPLE<"_rn_ftz_sat_f32", int_nvvm_fma_rn_ftz_sat_f, B32>,
FMA_TUPLE<"_rn_f32", int_nvvm_fma_rn_f, B32>,
FMA_TUPLE<"_rn_sat_f32", int_nvvm_fma_rn_sat_f, B32>,
FMA_TUPLE<"_rz_ftz_f32", int_nvvm_fma_rz_ftz_f, B32>,
FMA_TUPLE<"_rz_ftz_sat_f32", int_nvvm_fma_rz_ftz_sat_f, B32>,
FMA_TUPLE<"_rz_f32", int_nvvm_fma_rz_f, B32>,
FMA_TUPLE<"_rz_sat_f32", int_nvvm_fma_rz_sat_f, B32>,
FMA_TUPLE<"_rm_f32", int_nvvm_fma_rm_f, B32>,
FMA_TUPLE<"_rm_sat_f32", int_nvvm_fma_rm_sat_f, B32>,
FMA_TUPLE<"_rm_ftz_f32", int_nvvm_fma_rm_ftz_f, B32>,
FMA_TUPLE<"_rm_ftz_sat_f32", int_nvvm_fma_rm_ftz_sat_f, B32>,
FMA_TUPLE<"_rp_f32", int_nvvm_fma_rp_f, B32>,
FMA_TUPLE<"_rp_sat_f32", int_nvvm_fma_rp_sat_f, B32>,
FMA_TUPLE<"_rp_ftz_f32", int_nvvm_fma_rp_ftz_f, B32>,
FMA_TUPLE<"_rp_ftz_sat_f32", int_nvvm_fma_rp_ftz_sat_f, B32>,

FMA_TUPLE<"_rn_f16", int_nvvm_fma_rn_f16, B16, [hasPTX<42>, hasSM<53>]>,
FMA_TUPLE<"_rn_ftz_f16", int_nvvm_fma_rn_ftz_f16, B16,
Expand Down Expand Up @@ -1694,6 +1702,32 @@ multiclass FMA_INST {

defm INT_NVVM_FMA : FMA_INST;

// Mixed-precision fma: the two f16/bf16 multiplicands are widened to f32 and
// fused with an f32 addend in one instruction. The mnemonic is built by
// turning underscores into dots, e.g. "fma.rn.sat.f32.f16". Selected from
// the f32 fma intrinsics when both multiplicands are fpextends of the same
// narrow type. Gated on sm_100 / PTX 8.6.
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach sat = ["", "_sat"] in {
foreach type = [f16, bf16] in {
def INT_NVVM_MIXED_FMA # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
!subst("_", ".", "fma" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
(!cast<Intrinsic>("int_nvvm_fma" # rnd # sat # "_f")
(f32 (fpextend type:$a)),
(f32 (fpextend type:$b)),
f32:$c))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
} // type
} // sat
} // rnd

// Pattern for llvm.fma.f32 intrinsic when there is no FTZ flag.
// Only the _rn (round-to-nearest, non-saturating) instruction is used here,
// since that matches the default semantics of a plain `fma` node; the
// doNoF32FTZ predicate keeps this from firing when FTZ is requested
// (the mixed instructions defined above have no ftz form).
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
def : Pat<(f32 (fma (f32 (fpextend f16:$a)),
(f32 (fpextend f16:$b)), f32:$c)),
(INT_NVVM_MIXED_FMA_rn_f32_f16 B16:$a, B16:$b, B32:$c)>;
def : Pat<(f32 (fma (f32 (fpextend bf16:$a)),
(f32 (fpextend bf16:$b)), f32:$c)),
(INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>;
}

//
// Rcp
//
Expand Down Expand Up @@ -1793,19 +1827,99 @@ let Predicates = [doRsqrtOpt] in {
//

// f32 add instructions. NOTE: the PTX ISA grammar is
//   add{.rnd}{.ftz}{.sat}.f32
// i.e. .ftz must precede .sat in the mnemonic (as the fma.*.ftz.sat.f32
// tuples above already do); the previous "add.*.sat.ftz.f32" spellings
// were not valid PTX. Record names are kept unchanged.
def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
def INT_NVVM_ADD_RN_SAT_FTZ_F : F_MATH_2<"add.rn.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>;
def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>;
def INT_NVVM_ADD_RN_SAT_F : F_MATH_2<"add.rn.sat.f32", B32, B32, B32, int_nvvm_add_rn_sat_f>;
def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>;
def INT_NVVM_ADD_RZ_SAT_FTZ_F : F_MATH_2<"add.rz.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rz_ftz_sat_f>;
def INT_NVVM_ADD_RZ_F : F_MATH_2<"add.rz.f32", B32, B32, B32, int_nvvm_add_rz_f>;
def INT_NVVM_ADD_RZ_SAT_F : F_MATH_2<"add.rz.sat.f32", B32, B32, B32, int_nvvm_add_rz_sat_f>;
def INT_NVVM_ADD_RM_FTZ_F : F_MATH_2<"add.rm.ftz.f32", B32, B32, B32, int_nvvm_add_rm_ftz_f>;
def INT_NVVM_ADD_RM_SAT_FTZ_F : F_MATH_2<"add.rm.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rm_ftz_sat_f>;
def INT_NVVM_ADD_RM_F : F_MATH_2<"add.rm.f32", B32, B32, B32, int_nvvm_add_rm_f>;
def INT_NVVM_ADD_RM_SAT_F : F_MATH_2<"add.rm.sat.f32", B32, B32, B32, int_nvvm_add_rm_sat_f>;
def INT_NVVM_ADD_RP_FTZ_F : F_MATH_2<"add.rp.ftz.f32", B32, B32, B32, int_nvvm_add_rp_ftz_f>;
def INT_NVVM_ADD_RP_SAT_FTZ_F : F_MATH_2<"add.rp.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rp_ftz_sat_f>;
def INT_NVVM_ADD_RP_F : F_MATH_2<"add.rp.f32", B32, B32, B32, int_nvvm_add_rp_f>;
def INT_NVVM_ADD_RP_SAT_F : F_MATH_2<"add.rp.sat.f32", B32, B32, B32, int_nvvm_add_rp_sat_f>;

// f64 add: rounding modes only (no ftz/sat in PTX for f64 add).
def INT_NVVM_ADD_RN_D : F_MATH_2<"add.rn.f64", B64, B64, B64, int_nvvm_add_rn_d>;
def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>;
def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>;
def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>;

// Mixed-precision add: the f16/bf16 operand is widened to f32 and added to
// an f32 operand in one instruction (mnemonic e.g. "add.rn.sat.f32.f16").
// Selected from the f32 add intrinsics when the first source is an fpextend
// of the narrow type. Gated on sm_100 / PTX 8.6. (A review-comment thread
// from the diff render had been spliced into the middle of this loop; this
// is the code as written, reassembled.)
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
  foreach sat = ["", "_sat"] in {
    foreach type = [f16, bf16] in {
      def INT_NVVM_MIXED_ADD # rnd # sat # _f32_ # type :
        BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
          !subst("_", ".", "add" # rnd # sat # "_f32_" # type),
          [(set f32:$dst,
              (!cast<Intrinsic>("int_nvvm_add" # rnd # sat # "_f")
                (f32 (fpextend type:$a)),
                f32:$b))]>,
        Requires<[hasSM<100>, hasPTX<86>]>;
    } // type
  } // sat
} // rnd

// Pattern for fadd when there is no FTZ flag.
// Only the _rn, non-saturating mixed instruction matches the semantics of a
// plain `fadd` node; doNoF32FTZ excludes FTZ mode since the mixed
// instructions above have no ftz form. NOTE(review): fadd is commutative, so
// the matcher should also try the fpextend in the second operand — confirm.
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
def : Pat<(f32 (fadd (f32 (fpextend f16:$a)), f32:$b)),
(INT_NVVM_MIXED_ADD_rn_f32_f16 B16:$a, B32:$b)>;
def : Pat<(f32 (fadd (f32 (fpextend bf16:$a)), f32:$b)),
(INT_NVVM_MIXED_ADD_rn_f32_bf16 B16:$a, B32:$b)>;
}

//
// Sub
//

// f32 sub instructions. NOTE: the PTX ISA grammar is
//   sub{.rnd}{.ftz}{.sat}.f32
// i.e. .ftz must precede .sat in the mnemonic (as the fma.*.ftz.sat.f32
// tuples above already do); the previous "sub.*.sat.ftz.f32" spellings
// were not valid PTX. Record names are kept unchanged.
def INT_NVVM_SUB_RN_FTZ_F : F_MATH_2<"sub.rn.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_f>;
def INT_NVVM_SUB_RN_SAT_FTZ_F : F_MATH_2<"sub.rn.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f>;
def INT_NVVM_SUB_RN_F : F_MATH_2<"sub.rn.f32", B32, B32, B32, int_nvvm_sub_rn_f>;
def INT_NVVM_SUB_RN_SAT_F : F_MATH_2<"sub.rn.sat.f32", B32, B32, B32, int_nvvm_sub_rn_sat_f>;
def INT_NVVM_SUB_RZ_FTZ_F : F_MATH_2<"sub.rz.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_f>;
def INT_NVVM_SUB_RZ_SAT_FTZ_F : F_MATH_2<"sub.rz.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_sat_f>;
def INT_NVVM_SUB_RZ_F : F_MATH_2<"sub.rz.f32", B32, B32, B32, int_nvvm_sub_rz_f>;
def INT_NVVM_SUB_RZ_SAT_F : F_MATH_2<"sub.rz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_sat_f>;
def INT_NVVM_SUB_RM_FTZ_F : F_MATH_2<"sub.rm.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_f>;
def INT_NVVM_SUB_RM_SAT_FTZ_F : F_MATH_2<"sub.rm.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_sat_f>;
def INT_NVVM_SUB_RM_F : F_MATH_2<"sub.rm.f32", B32, B32, B32, int_nvvm_sub_rm_f>;
def INT_NVVM_SUB_RM_SAT_F : F_MATH_2<"sub.rm.sat.f32", B32, B32, B32, int_nvvm_sub_rm_sat_f>;
def INT_NVVM_SUB_RP_FTZ_F : F_MATH_2<"sub.rp.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_f>;
def INT_NVVM_SUB_RP_SAT_FTZ_F : F_MATH_2<"sub.rp.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_sat_f>;
def INT_NVVM_SUB_RP_F : F_MATH_2<"sub.rp.f32", B32, B32, B32, int_nvvm_sub_rp_f>;
def INT_NVVM_SUB_RP_SAT_F : F_MATH_2<"sub.rp.sat.f32", B32, B32, B32, int_nvvm_sub_rp_sat_f>;

// f64 sub: rounding modes only (no ftz/sat in PTX for f64 sub).
def INT_NVVM_SUB_RN_D : F_MATH_2<"sub.rn.f64", B64, B64, B64, int_nvvm_sub_rn_d>;
def INT_NVVM_SUB_RZ_D : F_MATH_2<"sub.rz.f64", B64, B64, B64, int_nvvm_sub_rz_d>;
def INT_NVVM_SUB_RM_D : F_MATH_2<"sub.rm.f64", B64, B64, B64, int_nvvm_sub_rm_d>;
def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>;

// Mixed-precision sub: the f16/bf16 minuend is widened to f32 and an f32
// subtrahend is subtracted in one instruction (mnemonic e.g.
// "sub.rn.sat.f32.f16"). Selected from the f32 sub intrinsics when the
// first source is an fpextend of the narrow type. Gated on sm_100 / PTX 8.6.
foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
foreach sat = ["", "_sat"] in {
foreach type = [f16, bf16] in {
def INT_NVVM_MIXED_SUB # rnd # sat # _f32_ # type :
BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
!subst("_", ".", "sub" # rnd # sat # "_f32_" # type),
[(set f32:$dst,
(!cast<Intrinsic>("int_nvvm_sub" # rnd # sat # "_f")
(f32 (fpextend type:$a)),
f32:$b))]>,
Requires<[hasSM<100>, hasPTX<86>]>;
} // type
} // sat
} // rnd

// Pattern for fsub when there is no FTZ flag.
// Only matches when the *first* (minuend) operand is the narrow fpextend,
// since PTX mixed sub widens the first source; fsub is not commutative, so
// no swapped pattern applies. Restricted to the _rn, non-saturating form,
// matching plain fsub semantics; doNoF32FTZ excludes FTZ mode.
let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in {
def : Pat<(f32 (fsub (f32 (fpextend f16:$a)), f32:$b)),
(INT_NVVM_MIXED_SUB_rn_f32_f16 B16:$a, B32:$b)>;
def : Pat<(f32 (fsub (f32 (fpextend bf16:$a)), f32:$b)),
(INT_NVVM_MIXED_SUB_rn_f32_bf16 B16:$a, B32:$b)>;
}

//
// BFIND
//
Expand Down
Loading