Commit 77a5de0

[NVPTX] lower VECREDUCE max/min to 3-input on sm_100+
Add support for the 3-input fmaxnum/fminnum/fmaximum/fminimum instructions introduced in PTX 8.8 for sm_100+. Other improvements:

- Fix the lowering of fmaxnum/fminnum so that they are not reassociated. According to the IEEE 754 definition of these operations (maxNum/minNum), they are not associative because of how they handle sNaNs. A quick example with a = 1.0, b = 1.0, c = sNaN:

      maxNum(a, maxNum(b, c)) = maxNum(a, qNaN) = a = 1.0
      maxNum(maxNum(a, b), c) = maxNum(1.0, sNaN) = qNaN

- Use a tree reduction when 3-input operations are supported and the reduction has the `reassoc` flag.
- If not on sm_100+/PTX 8.8, fall back to 2-input operations and use the default shuffle reduction.
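The non-associativity called out above can be reproduced with a few lines of standalone C++. This is a minimal sketch of the IEEE 754-2008 maxNum rules described in the message, not code from this patch; `maxNum754`, `makeSNaN`, and `isSNaN` are hypothetical helpers, and the bit-level sNaN handling assumes IEEE single precision and that sNaN payloads survive being passed by value (true on typical x86-64/SSE and AArch64 builds).

```cpp
// Minimal sketch of IEEE 754-2008 maxNum semantics (not code from this patch).
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Single-precision sNaN: all-ones exponent, quiet bit (bit 22) clear, nonzero payload.
static float makeSNaN() {
  const uint32_t Bits = 0x7f800001u;
  float F;
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}

static bool isSNaN(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  return std::isnan(F) && !(Bits & 0x00400000u);
}

// maxNum: a signaling NaN operand yields a quiet NaN result, while a quiet NaN
// operand is ignored in favor of the other operand.
static float maxNum754(float A, float B) {
  if (isSNaN(A) || isSNaN(B))
    return std::numeric_limits<float>::quiet_NaN();
  if (std::isnan(A))
    return B;
  if (std::isnan(B))
    return A;
  return A > B ? A : B;
}

int main() {
  const float A = 1.0f, B = 1.0f, C = makeSNaN();
  // Right-associated: the sNaN is quieted first and then ignored -> 1.0.
  std::printf("maxNum(a, maxNum(b, c)) = %f\n", maxNum754(A, maxNum754(B, C)));
  // Left-associated: the sNaN reaches the outer maxNum -> NaN.
  std::printf("maxNum(maxNum(a, b), c) = %f\n", maxNum754(maxNum754(A, B), C));
}
```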
1 parent 6193dd5 commit 77a5de0

File tree

5 files changed: +993 −665 lines changed


llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 207 additions & 0 deletions
@@ -850,6 +850,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);

+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1096,6 +1109,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2078,6 +2095,191 @@ static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
   return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
 }

+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// amount of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      // Handle leftover elements.
+
+      if (ReducedLevel.empty()) {
+        // We didn't reduce anything at this level. We need to pick a smaller
+        // operator.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // We reduced some things but there's still more left, meaning the
+      // operator's number of inputs doesn't evenly divide this level size. Move
+      // these elements to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector = Op.getOperand(0);
+  SDValue Accumulator;
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+      // Can't use fmaxnum3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // Definition of maxNum in IEEE 754 2008 is non-associative due to handling
+    // of sNaN inputs.
+    IsAssociative = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+      // Can't use fminnum3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // Definition of minNum in IEEE 754 2008 is non-associative due to handling
+    // of sNaN inputs.
+    IsAssociative = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If shuffle reduction is preferred, leave it to SelectionDAG.
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to sequential reduction.
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2957,6 +3159,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
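As a side note, the grouping performed by BuildTreeReduction above can be illustrated with a standalone sketch. This is not the DAG code from the patch: it mirrors the level-by-level loop on plain floats, using a hypothetical 3-input max followed by the 2-input fallback, so the names (`treeReduce`, `Max3`, `Max2`) and the use of `std::function` are illustrative assumptions.

```cpp
// Standalone sketch of the grouping performed by BuildTreeReduction
// (illustrative only; operates on plain floats rather than SDValues).
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

using Op = std::function<float(const std::vector<float> &)>;

static float treeReduce(std::vector<float> Level,
                        const std::vector<std::pair<unsigned, Op>> &Ops) {
  unsigned OpIdx = 0;
  while (Level.size() > 1) {
    const auto &[GroupSize, Reduce] = Ops[OpIdx];
    std::vector<float> Next;
    size_t I = 0, E = Level.size();
    // Combine adjacent groups of GroupSize elements, as much as possible.
    for (; I + GroupSize <= E; I += GroupSize)
      Next.push_back(Reduce(std::vector<float>(Level.begin() + I,
                                               Level.begin() + I + GroupSize)));
    if (I < E) {
      if (Next.empty()) {
        // Nothing fit this arity; fall back to the next (smaller) operator.
        ++OpIdx;
        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
        continue;
      }
      // Leftover elements move up to the next level unchanged.
      Next.insert(Next.end(), Level.begin() + I, Level.end());
    }
    Level = std::move(Next);
  }
  return Level.front();
}

int main() {
  const Op Max3 = [](const std::vector<float> &V) {
    return std::max({V[0], V[1], V[2]});
  };
  const Op Max2 = [](const std::vector<float> &V) { return std::max(V[0], V[1]); };
  // Eight elements: max3(e0,e1,e2), max3(e3,e4,e5), max3(r0,r1,e6), max(r2,e7).
  std::vector<float> V = {3, 1, 4, 1, 5, 9, 2, 6};
  std::printf("%f\n", treeReduce(V, {{3, Max3}, {2, Max2}}));
}
```

With eight elements the level sizes go 8 → 4 → 2 → 1, i.e. three 3-input operations plus one final 2-input operation; the non-`reassoc` path in LowerVECREDUCE instead feeds the same operator list into the strictly left-to-right sequential loop.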

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
@@ -64,6 +64,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,

   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -287,6 +292,7 @@ class NVPTXTargetLowering : public TargetLowering {

   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 44 additions & 0 deletions
@@ -358,6 +358,36 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
     Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }

+// Template for 3-input minimum/maximum instructions
+// (sm_100+/PTX 8.8 and f32 only)
+//
+// Also defines ftz (flush subnormal inputs and results to sign-preserving
+// zero) variants for fp32 functions.
+multiclass FMINIMUMMAXIMUM3<string OpcStr, bit NaN, SDNode OpNode> {
+  defvar nan_str = !if(NaN, ".NaN", "");
+  def f32rrr :
+    BasicFlagsNVPTXInst<(outs B32:$dst),
+                        (ins B32:$a, B32:$b, B32:$c),
+                        (ins FTZFlag:$ftz),
+                        OpcStr # "$ftz" # nan_str # ".f32",
+                        [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+    Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rri :
+    BasicFlagsNVPTXInst<(outs B32:$dst),
+                        (ins B32:$a, B32:$b, f32imm:$c),
+                        (ins FTZFlag:$ftz),
+                        OpcStr # "$ftz" # nan_str # ".f32",
+                        [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
+    Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rii :
+    BasicFlagsNVPTXInst<(outs B32:$dst),
+                        (ins B32:$a, f32imm:$b, f32imm:$c),
+                        (ins FTZFlag:$ftz),
+                        OpcStr # "$ftz" # nan_str # ".f32",
+                        [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+    Requires<[hasPTX<88>, hasSM<100>]>;
+}
+
 // Template for instructions which take three FP args. The
 // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
 //
@@ -1027,6 +1057,20 @@ defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
 defm MIN_NAN : FMINIMUMMAXIMUM<"min", /* NaN */ true, fminimum>;
 defm MAX_NAN : FMINIMUMMAXIMUM<"max", /* NaN */ true, fmaximum>;

+def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative]>;
+def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative]>;
+def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative]>;
+def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative]>;
+
+defm FMIN3 : FMINIMUMMAXIMUM3<"min", /* NaN */ false, nvptx_fminnum3>;
+defm FMAX3 : FMINIMUMMAXIMUM3<"max", /* NaN */ false, nvptx_fmaxnum3>;
+defm FMINNAN3 : FMINIMUMMAXIMUM3<"min", /* NaN */ true, nvptx_fminimum3>;
+defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max", /* NaN */ true, nvptx_fmaximum3>;
+
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
 defm FABS_H: F2_Support_Half<"abs", fabs>;

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 4 additions & 0 deletions
@@ -87,6 +87,10 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
   }
   unsigned getMinVectorRegisterBitWidth() const override { return 32; }

+  bool shouldExpandReduction(const IntrinsicInst *II) const override {
+    return false;
+  }
+
   // We don't want to prevent inlining because of target-cpu and -features
   // attributes that were added to newer versions of LLVM/Clang: There are
   // no incompatible functions in PTX, ptxas will throw errors in such cases.
