@@ -847,6 +847,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
  if (STI.allowFP16Math() || STI.hasBF16Math())
    setTargetDAGCombine(ISD::SETCC);

+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
  // Promote fp16 arithmetic if fp16 hardware isn't available or the
  // user passed --nvptx-no-fp16-math. The flag is useful because,
  // although sm_53+ GPUs have some sort of FP16 support in
@@ -1091,6 +1112,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
    MAKE_CASE(NVPTXISD::BFI)
    MAKE_CASE(NVPTXISD::PRMT)
    MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
    MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
    MAKE_CASE(NVPTXISD::STACKRESTORE)
    MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2060,6 +2085,261 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
  return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}

+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// amount of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      // Handle leftover elements.
+
+      if (ReducedLevel.empty()) {
+        // We didn't reduce anything at this level. We need to pick a smaller
+        // operator.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // We reduced some things but there's still more left, meaning the
+      // operator's number of inputs doesn't evenly divide this level size. Move
+      // these elements to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
+    // Prefer add.{f16,bf16,f32}x2 for v2{f16,bf16,f32}
+    PrefersShuffle =
+        EltTy == MVT::f16 || EltTy == MVT::bf16 || EltTy == MVT::f32;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
+    // Prefer mul.{f16,bf16,f32}x2 for v2{f16,bf16,f32}
+    PrefersShuffle =
+        EltTy == MVT::f16 || EltTy == MVT::bf16 || EltTy == MVT::f32;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // Definition of maxNum in IEEE 754 2008 is non-associative due to handling
+    // of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // Definition of minNum in IEEE 754 2008 is non-associative due to handling
+    // of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsAssociative = true;
+    // Prefer add.{s,u}16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsAssociative = true;
+    // Integer multiply doesn't support packed types
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsAssociative = true;
+    // Prefer and.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsAssociative = true;
+    // Prefer or.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsAssociative = true;
+    // Prefer xor.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If shuffle reduction is preferred, leave it to SelectionDAG.
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to sequential reduction.
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
  // Handle bitcasting from v2i8 without hitting the default promotion
  // strategy which goes through stack memory.
@@ -2903,6 +3183,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
    return LowerVECTOR_SHUFFLE(Op, DAG);
  case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
  case ISD::STORE:
    return LowerSTORE(Op, DAG);
  case ISD::LOAD:
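
The bottom-up grouping that BuildTreeReduction performs in this patch can be illustrated outside of SelectionDAG. The sketch below is only an approximation, assuming a hypothetical treeReduce helper on plain float values with a single binary operator; the actual routine operates on SDValues, can fall back from wider operators such as fmaxnum3/fminnum3 to binary ones, and propagates the reduction's SDNode flags.

// Standalone sketch of the bottom-up tree-reduction grouping (hypothetical
// helper, not part of the patch). Adjacent elements are combined first and
// leftovers are carried up to the next level, mirroring BuildTreeReduction.
#include <cassert>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

static float treeReduce(std::vector<float> Level,
                        const std::function<float(float, float)> &Op) {
  assert(!Level.empty() && "expected at least one element");
  while (Level.size() > 1) {
    std::vector<float> Reduced;
    size_t I = 0, E = Level.size();
    // Combine adjacent pairs, as much as possible.
    for (; I + 2 <= E; I += 2)
      Reduced.push_back(Op(Level[I], Level[I + 1]));
    // Any leftover element moves to the next level unchanged.
    for (; I < E; ++I)
      Reduced.push_back(Level[I]);
    Level = std::move(Reduced);
  }
  return Level.front();
}

int main() {
  // ((1 + 2) + (3 + 4)) + 5 == 15: adjacent elements are grouped first,
  // unlike a shuffle reduction, which would pair element I with I + E/2.
  std::printf("%g\n", treeReduce({1, 2, 3, 4, 5}, std::plus<float>()));
}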