Commit 72554b0

[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for sm_100+ fmax3/fmin3 instructions, introduced in PTX 8.8.

This method of tree reduction has a few benefits over the default in DAGTypeLegalizer:

- The default shuffle reduction progressively halves and partially reduces the vector until a single element remains. This produces a sequence of operations that combine disparate elements of the vector. For example, `vecreduce_fadd <4 x f32><a b c d>` gives `(a + c) + (b + d)`, whereas the tree reduction produces `(a + b) + (c + d)` by grouping nearby elements together first. Both use the same number of registers, but the shuffle reduction has longer live ranges. The same example is shown below: note that the shuffle reduction holds onto 3 registers for 2 cycles, while the tree reduction does so for only 1 cycle.

  (shuffle reduction)
  PTX:
    %r1 = add.f32 a, c
    %r2 = add.f32 b, d
    %r3 = add.f32 %r1, %r2

  Pipeline:
    cycles ---->
    | 1            | 2            | 3            | 4                  | 5                  | 6                      |
    | a = load.f32 | b = load.f32 | c = load.f32 | d = load.f32       |                    |                        |
    |              |              |              | %r1 = add.f32 a, c | %r2 = add.f32 b, d | %r3 = add.f32 %r1, %r2 |
    live regs ---->
    | a [1R]       | a b [2R]     | a b c [3R]   | b d %r1 [3R]       | %r1 %r2 [2R]       | %r3 [1R]               |

  (tree reduction)
  PTX:
    %r1 = add.f32 a, b
    %r2 = add.f32 c, d
    %r3 = add.f32 %r1, %r2

  Pipeline:
    cycles ---->
    | 1            | 2            | 3                  | 4            | 5                  | 6                      |
    | a = load.f32 | b = load.f32 | c = load.f32       | d = load.f32 |                    |                        |
    |              |              | %r1 = add.f32 a, b |              | %r2 = add.f32 c, d | %r3 = add.f32 %r1, %r2 |
    live regs ---->
    | a [1R]       | a b [2R]     | c %r1 [2R]         | c %r1 d [3R] | %r1 %r2 [2R]       | %r3 [1R]               |

- The shuffle reduction cannot easily support fmax3/fmin3 because it progressively halves the input vector.

- Faster compile time. Happens in one pass over the intrinsic, rather than O(N) passes if iteratively splitting the vector operands.
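For concreteness, a minimal IR-level example of the intrinsics this targets (illustrative only: the function names and the zero start value are made up; the llvm.vector.reduce.fadd intrinsic and the reassoc flag are standard LLVM forms):

  ; Reassociable form: selected as VECREDUCE_FADD, eligible for the tree
  ; reduction shown above.
  define float @reduce_fadd(<4 x float> %v) {
    %r = call reassoc float @llvm.vector.reduce.fadd.v4f32(float 0.0, <4 x float> %v)
    ret float %r
  }

  ; Strict form: selected as VECREDUCE_SEQ_FADD, which carries the start value
  ; as an accumulator and lowers to a sequential reduction.
  define float @reduce_fadd_seq(float %start, <4 x float> %v) {
    %r = call float @llvm.vector.reduce.fadd.v4f32(float %start, <4 x float> %v)
    ret float %r
  }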
1 parent 4811552 commit 72554b0

5 files changed (+1107, -696 lines)


llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (298 additions, 0 deletions)
@@ -847,6 +847,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1091,6 +1112,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2060,6 +2085,261 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// amount of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      // Handle leftover elements.
+
+      if (ReducedLevel.empty()) {
+        // We didn't reduce anything at this level. We need to pick a smaller
+        // operator.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // We reduced some things but there's still more left, meaning the
+      // operator's number of inputs doesn't evenly divide this level size. Move
+      // these elements to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
+    // Prefer add.{f16,bf16,f32}x2 for v2{f16,bf16,f32}
+    PrefersShuffle =
+        EltTy == MVT::f16 || EltTy == MVT::bf16 || EltTy == MVT::f32;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
+    // Prefer mul.{f16,bf16,f32}x2 for v2{f16,bf16,f32}
+    PrefersShuffle =
+        EltTy == MVT::f16 || EltTy == MVT::bf16 || EltTy == MVT::f32;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // Definition of maxNum in IEEE 754 2008 is non-associative due to handling
+    // of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // Definition of minNum in IEEE 754 2008 is non-associative due to handling
+    // of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsAssociative = true;
+    // Prefer add.{s,u}16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsAssociative = true;
+    // Integer multiply doesn't support packed types
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsAssociative = true;
+    // Prefer and.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsAssociative = true;
+    // Prefer or.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsAssociative = true;
+    // Prefer xor.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If shuffle reduction is preferred, leave it to SelectionDAG.
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to sequential reduction.
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2903,6 +3183,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
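To make the control flow of BuildTreeReduction above easier to follow, here is a standalone sketch of the same level-collapsing scheme over plain integers (a sketch under assumed simplifications, not LLVM code: integer max stands in for the scalar op, and a group size of 3 models a 3-input operator like fmax3 with a 2-input fallback):

  // Sketch of BuildTreeReduction: GroupSizes plays the role of ScalarOps,
  // widest operator first, falling back when a level gets too small.
  #include <algorithm>
  #include <cassert>
  #include <utility>
  #include <vector>

  static int buildTreeReduction(std::vector<int> Level,
                                const std::vector<unsigned> &GroupSizes) {
    unsigned OpIdx = 0;
    while (Level.size() > 1) {
      const unsigned N = GroupSizes[OpIdx];
      std::vector<int> Reduced;
      size_t I = 0, E = Level.size();
      // Reduce adjacent elements in groups of N, as much as possible.
      for (; I + N <= E; I += N)
        Reduced.push_back(
            *std::max_element(Level.begin() + I, Level.begin() + I + N));
      if (I < E) {
        if (Reduced.empty()) {
          // This level is smaller than the current group size: fall back to
          // the next (smaller) operator.
          ++OpIdx;
          assert(OpIdx < GroupSizes.size() && "no smaller operators");
          continue;
        }
        // Leftover elements move unchanged to the next level.
        for (; I < E; ++I)
          Reduced.push_back(Level[I]);
      }
      Level = std::move(Reduced);
    }
    return Level.front();
  }

  int main() {
    // 8 elements with {3, 2}: {a..h} -> {r1 r2 g h} -> {r3 h} -> max(r3, h).
    std::vector<int> V = {3, 1, 4, 1, 5, 9, 2, 6};
    return buildTreeReduction(V, {3, 2}) == 9 ? 0 : 1;
  }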

llvm/lib/Target/NVPTX/NVPTXISelLowering.h (6 additions, 0 deletions)

@@ -65,6 +65,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -283,6 +288,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
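Taken together, these changes let an f32 max reduction use the new 3-input nodes on recent targets. A rough end-to-end illustration (hypothetical function name; the grouping follows BuildTreeReduction with ScalarOps = {FMAXNUM3/3, FMAXNUM/2}, and actual instruction selection may differ):

  ; With reassoc (or unsafe FP math) on sm_100+ and PTX 8.8, the expected
  ; tree over elements a..h is:
  ;   max(max3(max3(a,b,c), max3(d,e,f), g), h)
  define float @reduce_fmax(<8 x float> %v) {
    %m = call reassoc float @llvm.vector.reduce.fmax.v8f32(<8 x float> %v)
    ret float %m
  }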
