-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[clang][NVPTX] Add support for mixed-precision FP arithmetic #168359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
02db2fe
444a0a7
9b02a28
98876d8
f72ed2f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1637,13 +1637,21 @@ multiclass FMA_INST { | |
| FMA_TUPLE<"_rp_f64", int_nvvm_fma_rp_d, B64>, | ||
|
|
||
| FMA_TUPLE<"_rn_ftz_f32", int_nvvm_fma_rn_ftz_f, B32>, | ||
| FMA_TUPLE<"_rn_ftz_sat_f32", int_nvvm_fma_rn_ftz_sat_f, B32>, | ||
| FMA_TUPLE<"_rn_f32", int_nvvm_fma_rn_f, B32>, | ||
| FMA_TUPLE<"_rn_sat_f32", int_nvvm_fma_rn_sat_f, B32>, | ||
| FMA_TUPLE<"_rz_ftz_f32", int_nvvm_fma_rz_ftz_f, B32>, | ||
| FMA_TUPLE<"_rz_ftz_sat_f32", int_nvvm_fma_rz_ftz_sat_f, B32>, | ||
| FMA_TUPLE<"_rz_f32", int_nvvm_fma_rz_f, B32>, | ||
| FMA_TUPLE<"_rz_sat_f32", int_nvvm_fma_rz_sat_f, B32>, | ||
| FMA_TUPLE<"_rm_f32", int_nvvm_fma_rm_f, B32>, | ||
| FMA_TUPLE<"_rm_sat_f32", int_nvvm_fma_rm_sat_f, B32>, | ||
| FMA_TUPLE<"_rm_ftz_f32", int_nvvm_fma_rm_ftz_f, B32>, | ||
| FMA_TUPLE<"_rm_ftz_sat_f32", int_nvvm_fma_rm_ftz_sat_f, B32>, | ||
| FMA_TUPLE<"_rp_f32", int_nvvm_fma_rp_f, B32>, | ||
| FMA_TUPLE<"_rp_sat_f32", int_nvvm_fma_rp_sat_f, B32>, | ||
| FMA_TUPLE<"_rp_ftz_f32", int_nvvm_fma_rp_ftz_f, B32>, | ||
| FMA_TUPLE<"_rp_ftz_sat_f32", int_nvvm_fma_rp_ftz_sat_f, B32>, | ||
|
|
||
| FMA_TUPLE<"_rn_f16", int_nvvm_fma_rn_f16, B16, [hasPTX<42>, hasSM<53>]>, | ||
| FMA_TUPLE<"_rn_ftz_f16", int_nvvm_fma_rn_ftz_f16, B16, | ||
|
|
@@ -1694,6 +1702,22 @@ multiclass FMA_INST { | |
|
|
||
| defm INT_NVVM_FMA : FMA_INST; | ||
|
|
||
| foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in { | ||
| foreach sat = ["", "_SAT"] in { | ||
| foreach type = ["F16", "BF16"] in { | ||
Wolfram70 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def INT_NVVM_FMA # rnd # sat # _F32_ # type : | ||
| BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c), | ||
Wolfram70 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| !tolower(!subst("_", ".", "fma" # rnd # sat # "_f32_" # type)), | ||
| [(set f32:$dst, | ||
| (!cast<Intrinsic>(!tolower("int_nvvm_fma" # rnd # sat # "_f")) | ||
| (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)), | ||
| (f32 (fpextend !cast<ValueType>(!tolower(type)):$b)), | ||
| f32:$c))]>, | ||
| Requires<[hasSM<100>, hasPTX<86>]>; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // | ||
| // Rcp | ||
| // | ||
|
|
@@ -1793,19 +1817,81 @@ let Predicates = [doRsqrtOpt] in { | |
| // | ||
|
|
||
| def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>; | ||
| def INT_NVVM_ADD_RN_SAT_FTZ_F : F_MATH_2<"add.rn.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>; | ||
| def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>; | ||
| def INT_NVVM_ADD_RN_SAT_F : F_MATH_2<"add.rn.sat.f32", B32, B32, B32, int_nvvm_add_rn_sat_f>; | ||
| def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>; | ||
| def INT_NVVM_ADD_RZ_SAT_FTZ_F : F_MATH_2<"add.rz.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_sat_f>; | ||
| def INT_NVVM_ADD_RZ_F : F_MATH_2<"add.rz.f32", B32, B32, B32, int_nvvm_add_rz_f>; | ||
| def INT_NVVM_ADD_RZ_SAT_F : F_MATH_2<"add.rz.sat.f32", B32, B32, B32, int_nvvm_add_rz_sat_f>; | ||
| def INT_NVVM_ADD_RM_FTZ_F : F_MATH_2<"add.rm.ftz.f32", B32, B32, B32, int_nvvm_add_rm_ftz_f>; | ||
| def INT_NVVM_ADD_RM_SAT_FTZ_F : F_MATH_2<"add.rm.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rm_ftz_sat_f>; | ||
| def INT_NVVM_ADD_RM_F : F_MATH_2<"add.rm.f32", B32, B32, B32, int_nvvm_add_rm_f>; | ||
| def INT_NVVM_ADD_RM_SAT_F : F_MATH_2<"add.rm.sat.f32", B32, B32, B32, int_nvvm_add_rm_sat_f>; | ||
| def INT_NVVM_ADD_RP_FTZ_F : F_MATH_2<"add.rp.ftz.f32", B32, B32, B32, int_nvvm_add_rp_ftz_f>; | ||
| def INT_NVVM_ADD_RP_SAT_FTZ_F : F_MATH_2<"add.rp.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rp_ftz_sat_f>; | ||
| def INT_NVVM_ADD_RP_F : F_MATH_2<"add.rp.f32", B32, B32, B32, int_nvvm_add_rp_f>; | ||
| def INT_NVVM_ADD_RP_SAT_F : F_MATH_2<"add.rp.sat.f32", B32, B32, B32, int_nvvm_add_rp_sat_f>; | ||
|
|
||
| def INT_NVVM_ADD_RN_D : F_MATH_2<"add.rn.f64", B64, B64, B64, int_nvvm_add_rn_d>; | ||
| def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>; | ||
| def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>; | ||
| def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>; | ||
|
|
||
| foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in { | ||
| foreach sat = ["", "_SAT"] in { | ||
| foreach type = ["F16", "BF16"] in { | ||
| def INT_NVVM_ADD # rnd # sat # _F32_ # type : | ||
| BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the mixed FP instructions, should we also add folds for the generic fadd/fsub/fma instructions in LLVM?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense. I've added patterns to fold |
||
| !tolower(!subst("_", ".", "add" # rnd # sat # "_f32_" # type)), | ||
| [(set f32:$dst, | ||
| (!cast<Intrinsic>(!tolower("int_nvvm_add" # rnd # sat # "_f")) | ||
| (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)), | ||
durga4github marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| f32:$b))]>, | ||
| Requires<[hasSM<100>, hasPTX<86>]>; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Sub | ||
| // | ||
|
|
||
| def INT_NVVM_SUB_RN_FTZ_F : F_MATH_2<"sub.rn.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_f>; | ||
| def INT_NVVM_SUB_RN_SAT_FTZ_F : F_MATH_2<"sub.rn.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f>; | ||
| def INT_NVVM_SUB_RN_F : F_MATH_2<"sub.rn.f32", B32, B32, B32, int_nvvm_sub_rn_f>; | ||
| def INT_NVVM_SUB_RN_SAT_F : F_MATH_2<"sub.rn.sat.f32", B32, B32, B32, int_nvvm_sub_rn_sat_f>; | ||
| def INT_NVVM_SUB_RZ_FTZ_F : F_MATH_2<"sub.rz.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_f>; | ||
| def INT_NVVM_SUB_RZ_SAT_FTZ_F : F_MATH_2<"sub.rz.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_sat_f>; | ||
| def INT_NVVM_SUB_RZ_F : F_MATH_2<"sub.rz.f32", B32, B32, B32, int_nvvm_sub_rz_f>; | ||
| def INT_NVVM_SUB_RZ_SAT_F : F_MATH_2<"sub.rz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_sat_f>; | ||
| def INT_NVVM_SUB_RM_FTZ_F : F_MATH_2<"sub.rm.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_f>; | ||
| def INT_NVVM_SUB_RM_SAT_FTZ_F : F_MATH_2<"sub.rm.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_sat_f>; | ||
| def INT_NVVM_SUB_RM_F : F_MATH_2<"sub.rm.f32", B32, B32, B32, int_nvvm_sub_rm_f>; | ||
| def INT_NVVM_SUB_RM_SAT_F : F_MATH_2<"sub.rm.sat.f32", B32, B32, B32, int_nvvm_sub_rm_sat_f>; | ||
| def INT_NVVM_SUB_RP_FTZ_F : F_MATH_2<"sub.rp.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_f>; | ||
| def INT_NVVM_SUB_RP_SAT_FTZ_F : F_MATH_2<"sub.rp.sat.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_sat_f>; | ||
| def INT_NVVM_SUB_RP_F : F_MATH_2<"sub.rp.f32", B32, B32, B32, int_nvvm_sub_rp_f>; | ||
| def INT_NVVM_SUB_RP_SAT_F : F_MATH_2<"sub.rp.sat.f32", B32, B32, B32, int_nvvm_sub_rp_sat_f>; | ||
|
|
||
| def INT_NVVM_SUB_RN_D : F_MATH_2<"sub.rn.f64", B64, B64, B64, int_nvvm_sub_rn_d>; | ||
| def INT_NVVM_SUB_RZ_D : F_MATH_2<"sub.rz.f64", B64, B64, B64, int_nvvm_sub_rz_d>; | ||
| def INT_NVVM_SUB_RM_D : F_MATH_2<"sub.rm.f64", B64, B64, B64, int_nvvm_sub_rm_d>; | ||
| def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>; | ||
|
|
||
| foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in { | ||
| foreach sat = ["", "_SAT"] in { | ||
| foreach type = ["F16", "BF16"] in { | ||
| def INT_NVVM_SUB # rnd # sat # _F32_ # type : | ||
| BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b), | ||
| !tolower(!subst("_", ".", "sub" # rnd # sat # "_f32_" # type)), | ||
| [(set f32:$dst, | ||
| (!cast<Intrinsic>(!tolower("int_nvvm_sub" # rnd # sat # "_f")) | ||
| (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)), | ||
| f32:$b))]>, | ||
| Requires<[hasSM<100>, hasPTX<86>]>; | ||
| } | ||
| } | ||
| } | ||
| // | ||
| // BFIND | ||
| // | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.