Skip to content

Commit 2e56936

Browse files
LewisCrawfordgithub-actions[bot]
authored andcommitted
Automerge: [NVPTX] Constant fold NVVM fmin and fmax (#121966)
Add constant-folding for nvvm float/double fmin + fmax intrinsics, including all combinations of xorsign.abs, nan-propagation, and ftz.
2 parents 110af9b + cea9244 commit 2e56936

File tree

3 files changed

+1222
-8
lines changed

3 files changed

+1222
-8
lines changed

llvm/include/llvm/IR/NVVMIntrinsicUtils.h

Lines changed: 168 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ enum class TMAReductionOp : uint8_t {
3838
XOR = 7,
3939
};
4040

41-
inline bool IntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
41+
inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
4242
switch (IntrinsicID) {
43-
// Float to i32 / i64 conversion intrinsics:
4443
case Intrinsic::nvvm_f2i_rm_ftz:
4544
case Intrinsic::nvvm_f2i_rn_ftz:
4645
case Intrinsic::nvvm_f2i_rp_ftz:
@@ -61,11 +60,53 @@ inline bool IntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
6160
case Intrinsic::nvvm_f2ull_rp_ftz:
6261
case Intrinsic::nvvm_f2ull_rz_ftz:
6362
return true;
63+
64+
case Intrinsic::nvvm_f2i_rm:
65+
case Intrinsic::nvvm_f2i_rn:
66+
case Intrinsic::nvvm_f2i_rp:
67+
case Intrinsic::nvvm_f2i_rz:
68+
69+
case Intrinsic::nvvm_f2ui_rm:
70+
case Intrinsic::nvvm_f2ui_rn:
71+
case Intrinsic::nvvm_f2ui_rp:
72+
case Intrinsic::nvvm_f2ui_rz:
73+
74+
case Intrinsic::nvvm_d2i_rm:
75+
case Intrinsic::nvvm_d2i_rn:
76+
case Intrinsic::nvvm_d2i_rp:
77+
case Intrinsic::nvvm_d2i_rz:
78+
79+
case Intrinsic::nvvm_d2ui_rm:
80+
case Intrinsic::nvvm_d2ui_rn:
81+
case Intrinsic::nvvm_d2ui_rp:
82+
case Intrinsic::nvvm_d2ui_rz:
83+
84+
case Intrinsic::nvvm_f2ll_rm:
85+
case Intrinsic::nvvm_f2ll_rn:
86+
case Intrinsic::nvvm_f2ll_rp:
87+
case Intrinsic::nvvm_f2ll_rz:
88+
89+
case Intrinsic::nvvm_f2ull_rm:
90+
case Intrinsic::nvvm_f2ull_rn:
91+
case Intrinsic::nvvm_f2ull_rp:
92+
case Intrinsic::nvvm_f2ull_rz:
93+
94+
case Intrinsic::nvvm_d2ll_rm:
95+
case Intrinsic::nvvm_d2ll_rn:
96+
case Intrinsic::nvvm_d2ll_rp:
97+
case Intrinsic::nvvm_d2ll_rz:
98+
99+
case Intrinsic::nvvm_d2ull_rm:
100+
case Intrinsic::nvvm_d2ull_rn:
101+
case Intrinsic::nvvm_d2ull_rp:
102+
case Intrinsic::nvvm_d2ull_rz:
103+
return false;
64104
}
105+
llvm_unreachable("Checking FTZ flag for invalid f2i/d2i intrinsic");
65106
return false;
66107
}
67108

68-
inline bool IntrinsicConvertsToSignedInteger(Intrinsic::ID IntrinsicID) {
109+
inline bool FPToIntegerIntrinsicResultIsSigned(Intrinsic::ID IntrinsicID) {
69110
switch (IntrinsicID) {
70111
// f2i
71112
case Intrinsic::nvvm_f2i_rm:
@@ -96,12 +137,44 @@ inline bool IntrinsicConvertsToSignedInteger(Intrinsic::ID IntrinsicID) {
96137
case Intrinsic::nvvm_d2ll_rp:
97138
case Intrinsic::nvvm_d2ll_rz:
98139
return true;
140+
141+
// f2ui
142+
case Intrinsic::nvvm_f2ui_rm:
143+
case Intrinsic::nvvm_f2ui_rm_ftz:
144+
case Intrinsic::nvvm_f2ui_rn:
145+
case Intrinsic::nvvm_f2ui_rn_ftz:
146+
case Intrinsic::nvvm_f2ui_rp:
147+
case Intrinsic::nvvm_f2ui_rp_ftz:
148+
case Intrinsic::nvvm_f2ui_rz:
149+
case Intrinsic::nvvm_f2ui_rz_ftz:
150+
// d2ui
151+
case Intrinsic::nvvm_d2ui_rm:
152+
case Intrinsic::nvvm_d2ui_rn:
153+
case Intrinsic::nvvm_d2ui_rp:
154+
case Intrinsic::nvvm_d2ui_rz:
155+
// f2ull
156+
case Intrinsic::nvvm_f2ull_rm:
157+
case Intrinsic::nvvm_f2ull_rm_ftz:
158+
case Intrinsic::nvvm_f2ull_rn:
159+
case Intrinsic::nvvm_f2ull_rn_ftz:
160+
case Intrinsic::nvvm_f2ull_rp:
161+
case Intrinsic::nvvm_f2ull_rp_ftz:
162+
case Intrinsic::nvvm_f2ull_rz:
163+
case Intrinsic::nvvm_f2ull_rz_ftz:
164+
// d2ull
165+
case Intrinsic::nvvm_d2ull_rm:
166+
case Intrinsic::nvvm_d2ull_rn:
167+
case Intrinsic::nvvm_d2ull_rp:
168+
case Intrinsic::nvvm_d2ull_rz:
169+
return false;
99170
}
171+
llvm_unreachable(
172+
"Checking invalid f2i/d2i intrinsic for signed int conversion");
100173
return false;
101174
}
102175

103176
inline APFloat::roundingMode
104-
IntrinsicGetRoundingMode(Intrinsic::ID IntrinsicID) {
177+
GetFPToIntegerRoundingMode(Intrinsic::ID IntrinsicID) {
105178
switch (IntrinsicID) {
106179
// RM:
107180
case Intrinsic::nvvm_f2i_rm:
@@ -167,10 +240,100 @@ IntrinsicGetRoundingMode(Intrinsic::ID IntrinsicID) {
167240
case Intrinsic::nvvm_d2ull_rz:
168241
return APFloat::rmTowardZero;
169242
}
170-
llvm_unreachable("Invalid f2i/d2i rounding mode intrinsic");
243+
llvm_unreachable("Checking rounding mode for invalid f2i/d2i intrinsic");
171244
return APFloat::roundingMode::Invalid;
172245
}
173246

247+
inline bool FMinFMaxShouldFTZ(Intrinsic::ID IntrinsicID) {
248+
switch (IntrinsicID) {
249+
case Intrinsic::nvvm_fmax_ftz_f:
250+
case Intrinsic::nvvm_fmax_ftz_nan_f:
251+
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
252+
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
253+
254+
case Intrinsic::nvvm_fmin_ftz_f:
255+
case Intrinsic::nvvm_fmin_ftz_nan_f:
256+
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
257+
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
258+
return true;
259+
260+
case Intrinsic::nvvm_fmax_d:
261+
case Intrinsic::nvvm_fmax_f:
262+
case Intrinsic::nvvm_fmax_nan_f:
263+
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
264+
case Intrinsic::nvvm_fmax_xorsign_abs_f:
265+
266+
case Intrinsic::nvvm_fmin_d:
267+
case Intrinsic::nvvm_fmin_f:
268+
case Intrinsic::nvvm_fmin_nan_f:
269+
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
270+
case Intrinsic::nvvm_fmin_xorsign_abs_f:
271+
return false;
272+
}
273+
llvm_unreachable("Checking FTZ flag for invalid fmin/fmax intrinsic");
274+
return false;
275+
}
276+
277+
inline bool FMinFMaxPropagatesNaNs(Intrinsic::ID IntrinsicID) {
278+
switch (IntrinsicID) {
279+
case Intrinsic::nvvm_fmax_ftz_nan_f:
280+
case Intrinsic::nvvm_fmax_nan_f:
281+
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
282+
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
283+
284+
case Intrinsic::nvvm_fmin_ftz_nan_f:
285+
case Intrinsic::nvvm_fmin_nan_f:
286+
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
287+
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
288+
return true;
289+
290+
case Intrinsic::nvvm_fmax_d:
291+
case Intrinsic::nvvm_fmax_f:
292+
case Intrinsic::nvvm_fmax_ftz_f:
293+
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
294+
case Intrinsic::nvvm_fmax_xorsign_abs_f:
295+
296+
case Intrinsic::nvvm_fmin_d:
297+
case Intrinsic::nvvm_fmin_f:
298+
case Intrinsic::nvvm_fmin_ftz_f:
299+
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
300+
case Intrinsic::nvvm_fmin_xorsign_abs_f:
301+
return false;
302+
}
303+
llvm_unreachable("Checking NaN flag for invalid fmin/fmax intrinsic");
304+
return false;
305+
}
306+
307+
inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
308+
switch (IntrinsicID) {
309+
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
310+
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
311+
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
312+
case Intrinsic::nvvm_fmax_xorsign_abs_f:
313+
314+
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
315+
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
316+
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
317+
case Intrinsic::nvvm_fmin_xorsign_abs_f:
318+
return true;
319+
320+
case Intrinsic::nvvm_fmax_d:
321+
case Intrinsic::nvvm_fmax_f:
322+
case Intrinsic::nvvm_fmax_ftz_f:
323+
case Intrinsic::nvvm_fmax_ftz_nan_f:
324+
case Intrinsic::nvvm_fmax_nan_f:
325+
326+
case Intrinsic::nvvm_fmin_d:
327+
case Intrinsic::nvvm_fmin_f:
328+
case Intrinsic::nvvm_fmin_ftz_f:
329+
case Intrinsic::nvvm_fmin_ftz_nan_f:
330+
case Intrinsic::nvvm_fmin_nan_f:
331+
return false;
332+
}
333+
llvm_unreachable("Checking XorSignAbs flag for invalid fmin/fmax intrinsic");
334+
return false;
335+
}
336+
174337
} // namespace nvvm
175338
} // namespace llvm
176339
#endif // LLVM_IR_NVVMINTRINSICUTILS_H

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,6 +1689,28 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
16891689
case Intrinsic::x86_avx512_cvttsd2usi64:
16901690
return !Call->isStrictFP();
16911691

1692+
// NVVM FMax intrinsics
1693+
case Intrinsic::nvvm_fmax_d:
1694+
case Intrinsic::nvvm_fmax_f:
1695+
case Intrinsic::nvvm_fmax_ftz_f:
1696+
case Intrinsic::nvvm_fmax_ftz_nan_f:
1697+
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
1698+
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
1699+
case Intrinsic::nvvm_fmax_nan_f:
1700+
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
1701+
case Intrinsic::nvvm_fmax_xorsign_abs_f:
1702+
1703+
// NVVM FMin intrinsics
1704+
case Intrinsic::nvvm_fmin_d:
1705+
case Intrinsic::nvvm_fmin_f:
1706+
case Intrinsic::nvvm_fmin_ftz_f:
1707+
case Intrinsic::nvvm_fmin_ftz_nan_f:
1708+
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
1709+
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
1710+
case Intrinsic::nvvm_fmin_nan_f:
1711+
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
1712+
case Intrinsic::nvvm_fmin_xorsign_abs_f:
1713+
16921714
// NVVM float/double to int32/uint32 conversion intrinsics
16931715
case Intrinsic::nvvm_f2i_rm:
16941716
case Intrinsic::nvvm_f2i_rn:
@@ -2431,9 +2453,10 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
24312453
if (U.isNaN())
24322454
return ConstantInt::get(Ty, 0);
24332455

2434-
APFloat::roundingMode RMode = nvvm::IntrinsicGetRoundingMode(IntrinsicID);
2435-
bool IsFTZ = nvvm::IntrinsicShouldFTZ(IntrinsicID);
2436-
bool IsSigned = nvvm::IntrinsicConvertsToSignedInteger(IntrinsicID);
2456+
APFloat::roundingMode RMode =
2457+
nvvm::GetFPToIntegerRoundingMode(IntrinsicID);
2458+
bool IsFTZ = nvvm::FPToIntegerIntrinsicShouldFTZ(IntrinsicID);
2459+
bool IsSigned = nvvm::FPToIntegerIntrinsicResultIsSigned(IntrinsicID);
24372460

24382461
APSInt ResInt(Ty->getIntegerBitWidth(), !IsSigned);
24392462
auto FloatToRound = IsFTZ ? FTZPreserveSign(U) : U;
@@ -2892,12 +2915,49 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
28922915
case Intrinsic::minnum:
28932916
case Intrinsic::maximum:
28942917
case Intrinsic::minimum:
2918+
case Intrinsic::nvvm_fmax_d:
2919+
case Intrinsic::nvvm_fmin_d:
28952920
// If one argument is undef, return the other argument.
28962921
if (IsOp0Undef)
28972922
return Operands[1];
28982923
if (IsOp1Undef)
28992924
return Operands[0];
29002925
break;
2926+
2927+
case Intrinsic::nvvm_fmax_f:
2928+
case Intrinsic::nvvm_fmax_ftz_f:
2929+
case Intrinsic::nvvm_fmax_ftz_nan_f:
2930+
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
2931+
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
2932+
case Intrinsic::nvvm_fmax_nan_f:
2933+
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
2934+
case Intrinsic::nvvm_fmax_xorsign_abs_f:
2935+
2936+
case Intrinsic::nvvm_fmin_f:
2937+
case Intrinsic::nvvm_fmin_ftz_f:
2938+
case Intrinsic::nvvm_fmin_ftz_nan_f:
2939+
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
2940+
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
2941+
case Intrinsic::nvvm_fmin_nan_f:
2942+
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
2943+
case Intrinsic::nvvm_fmin_xorsign_abs_f:
2944+
// If one arg is undef, the other arg can be returned only if it is
2945+
// constant, as we may need to flush it to sign-preserving zero or
2946+
// canonicalize the NaN.
2947+
if (!IsOp0Undef && !IsOp1Undef)
2948+
break;
2949+
if (auto *Op = dyn_cast<ConstantFP>(Operands[IsOp0Undef ? 1 : 0])) {
2950+
if (Op->isNaN()) {
2951+
APInt NVCanonicalNaN(32, 0x7fffffff);
2952+
return ConstantFP::get(
2953+
Ty, APFloat(Ty->getFltSemantics(), NVCanonicalNaN));
2954+
}
2955+
if (nvvm::FMinFMaxShouldFTZ(IntrinsicID))
2956+
return ConstantFP::get(Ty, FTZPreserveSign(Op->getValueAPF()));
2957+
else
2958+
return Op;
2959+
}
2960+
break;
29012961
}
29022962
}
29032963

@@ -2955,6 +3015,79 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
29553015
return ConstantFP::get(Ty->getContext(), minimum(Op1V, Op2V));
29563016
case Intrinsic::maximum:
29573017
return ConstantFP::get(Ty->getContext(), maximum(Op1V, Op2V));
3018+
3019+
case Intrinsic::nvvm_fmax_d:
3020+
case Intrinsic::nvvm_fmax_f:
3021+
case Intrinsic::nvvm_fmax_ftz_f:
3022+
case Intrinsic::nvvm_fmax_ftz_nan_f:
3023+
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
3024+
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
3025+
case Intrinsic::nvvm_fmax_nan_f:
3026+
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
3027+
case Intrinsic::nvvm_fmax_xorsign_abs_f:
3028+
3029+
case Intrinsic::nvvm_fmin_d:
3030+
case Intrinsic::nvvm_fmin_f:
3031+
case Intrinsic::nvvm_fmin_ftz_f:
3032+
case Intrinsic::nvvm_fmin_ftz_nan_f:
3033+
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
3034+
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
3035+
case Intrinsic::nvvm_fmin_nan_f:
3036+
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
3037+
case Intrinsic::nvvm_fmin_xorsign_abs_f: {
3038+
3039+
bool ShouldCanonicalizeNaNs = !(IntrinsicID == Intrinsic::nvvm_fmax_d ||
3040+
IntrinsicID == Intrinsic::nvvm_fmin_d);
3041+
bool IsFTZ = nvvm::FMinFMaxShouldFTZ(IntrinsicID);
3042+
bool IsNaNPropagating = nvvm::FMinFMaxPropagatesNaNs(IntrinsicID);
3043+
bool IsXorSignAbs = nvvm::FMinFMaxIsXorSignAbs(IntrinsicID);
3044+
3045+
APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
3046+
APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
3047+
3048+
bool XorSign = false;
3049+
if (IsXorSignAbs) {
3050+
XorSign = A.isNegative() ^ B.isNegative();
3051+
A = abs(A);
3052+
B = abs(B);
3053+
}
3054+
3055+
bool IsFMax = false;
3056+
switch (IntrinsicID) {
3057+
case Intrinsic::nvvm_fmax_d:
3058+
case Intrinsic::nvvm_fmax_f:
3059+
case Intrinsic::nvvm_fmax_ftz_f:
3060+
case Intrinsic::nvvm_fmax_ftz_nan_f:
3061+
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
3062+
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
3063+
case Intrinsic::nvvm_fmax_nan_f:
3064+
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
3065+
case Intrinsic::nvvm_fmax_xorsign_abs_f:
3066+
IsFMax = true;
3067+
break;
3068+
}
3069+
APFloat Res = IsFMax ? maximum(A, B) : minimum(A, B);
3070+
3071+
if (ShouldCanonicalizeNaNs) {
3072+
APFloat NVCanonicalNaN(Res.getSemantics(), APInt(32, 0x7fffffff));
3073+
if (A.isNaN() && B.isNaN())
3074+
return ConstantFP::get(Ty, NVCanonicalNaN);
3075+
else if (IsNaNPropagating && (A.isNaN() || B.isNaN()))
3076+
return ConstantFP::get(Ty, NVCanonicalNaN);
3077+
}
3078+
3079+
if (A.isNaN() && B.isNaN())
3080+
return Operands[1];
3081+
else if (A.isNaN())
3082+
Res = B;
3083+
else if (B.isNaN())
3084+
Res = A;
3085+
3086+
if (IsXorSignAbs && XorSign != Res.isNegative())
3087+
Res.changeSign();
3088+
3089+
return ConstantFP::get(Ty->getContext(), Res);
3090+
}
29583091
}
29593092

29603093
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())

0 commit comments

Comments
 (0)