@@ -2172,19 +2172,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21722172}
21732173
21742174/// A generic routine for constructing a tree reduction on a vector operand.
2175- /// This method differs from iterative splitting in DAGTypeLegalizer by
2176- /// progressively grouping elements bottom-up.
2175+ /// This method groups elements bottom-up, progressively building each level.
2176+ /// This approach differs from top-down iterative splitting used in
2177+ /// DAGTypeLegalizer and ExpandReductions.
2178+ ///
2179+ /// Also, the flags on the original reduction operation will be propagated to
2180+ /// each scalar operation.
21772181static SDValue BuildTreeReduction (
21782182 const SmallVector<SDValue> &Elements, EVT EltTy,
21792183 ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
21802184 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2181- // now build the computation graph in place at each level
2185+ // Build the reduction tree at each level, starting with all the elements.
21822186 SmallVector<SDValue> Level = Elements;
2187+
21832188 unsigned OpIdx = 0 ;
21842189 while (Level.size () > 1 ) {
2190+ // Try to reduce this level using the current operator.
21852191 const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21862192
2187- // partially reduce all elements in level
2193+ // Build the next level by partially reducing all elements.
21882194 SmallVector<SDValue> ReducedLevel;
21892195 unsigned I = 0 , E = Level.size ();
21902196 for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
@@ -2195,18 +2201,23 @@ static SDValue BuildTreeReduction(
21952201 }
21962202
21972203 if (I < E) {
2204+ // We have leftover elements. Why?
2205+
21982206 if (ReducedLevel.empty ()) {
2199- // The current operator requires more inputs than there are operands at
2200- // this level . Pick a smaller operator and retry.
2207+ // ...because this level is now so small that the current operator is
2208+ // too big for it . Pick a smaller operator and retry.
22012209 ++OpIdx;
22022210 assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
22032211 continue ;
22042212 }
22052213
2206- // Otherwise, we just have a remainder, which we push to the next level.
2214+ // ...because the operator's required number of inputs doesn't evenly
2215+ // divide the size of this level. We push this remainder to the next level.
22072216 for (; I < E; ++I)
22082217 ReducedLevel.push_back (Level[I]);
22092218 }
2219+
2220+ // Process the next level.
22102221 Level = ReducedLevel;
22112222 }
22122223
@@ -2222,6 +2233,7 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22222233 const SDNodeFlags Flags = Op->getFlags ();
22232234 SDValue Vector;
22242235 SDValue Accumulator;
2236+
22252237 if (Op->getOpcode () == ISD::VECREDUCE_SEQ_FADD ||
22262238 Op->getOpcode () == ISD::VECREDUCE_SEQ_FMUL) {
22272239 // special case with accumulator as first arg
@@ -2231,85 +2243,94 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22312243 // default case
22322244 Vector = Op.getOperand (0 );
22332245 }
2246+
22342247 EVT EltTy = Vector.getValueType ().getVectorElementType ();
22352248 const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
22362249 STI.getPTXVersion () >= 88 ;
22372250
22382251 // A list of SDNode opcodes with equivalent semantics, sorted descending by
22392252 // number of inputs they take.
22402253 SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2241- bool IsReassociatable;
2254+
2255+ // Whether we can lower to scalar operations in an arbitrary order.
2256+ bool IsAssociative;
22422257
22432258 switch (Op->getOpcode ()) {
22442259 case ISD::VECREDUCE_FADD:
22452260 case ISD::VECREDUCE_SEQ_FADD:
22462261 ScalarOps = {{ISD::FADD, 2 }};
2247- IsReassociatable = false ;
2262+ IsAssociative = Op-> getOpcode () == ISD::VECREDUCE_FADD ;
22482263 break ;
22492264 case ISD::VECREDUCE_FMUL:
22502265 case ISD::VECREDUCE_SEQ_FMUL:
22512266 ScalarOps = {{ISD::FMUL, 2 }};
2252- IsReassociatable = false ;
2267+ IsAssociative = Op-> getOpcode () == ISD::VECREDUCE_FMUL ;
22532268 break ;
22542269 case ISD::VECREDUCE_FMAX:
22552270 if (CanUseMinMax3)
22562271 ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
22572272 ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2258- IsReassociatable = false ;
2273+ // The IEEE 754-2008 definition of maxNum is non-associative, but only
2274+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2275+ // sNaNs.
2276+ IsAssociative = true ;
22592277 break ;
22602278 case ISD::VECREDUCE_FMIN:
22612279 if (CanUseMinMax3)
22622280 ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
22632281 ScalarOps.push_back ({ISD::FMINNUM, 2 });
2264- IsReassociatable = false ;
2282+ // The IEEE 754-2008 definition of minNum is non-associative, but only
2283+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2284+ // sNaNs.
2285+ IsAssociative = true ;
22652286 break ;
22662287 case ISD::VECREDUCE_FMAXIMUM:
22672288 if (CanUseMinMax3)
22682289 ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
22692290 ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2270- IsReassociatable = false ;
2291+ IsAssociative = true ;
22712292 break ;
22722293 case ISD::VECREDUCE_FMINIMUM:
22732294 if (CanUseMinMax3)
22742295 ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
22752296 ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2276- IsReassociatable = false ;
2297+ IsAssociative = true ;
22772298 break ;
22782299 case ISD::VECREDUCE_ADD:
22792300 ScalarOps = {{ISD::ADD, 2 }};
2280- IsReassociatable = true ;
2301+ IsAssociative = true ;
22812302 break ;
22822303 case ISD::VECREDUCE_MUL:
22832304 ScalarOps = {{ISD::MUL, 2 }};
2284- IsReassociatable = true ;
2305+ IsAssociative = true ;
22852306 break ;
22862307 case ISD::VECREDUCE_UMAX:
22872308 ScalarOps = {{ISD::UMAX, 2 }};
2288- IsReassociatable = true ;
2309+ IsAssociative = true ;
22892310 break ;
22902311 case ISD::VECREDUCE_UMIN:
22912312 ScalarOps = {{ISD::UMIN, 2 }};
2292- IsReassociatable = true ;
2313+ IsAssociative = true ;
22932314 break ;
22942315 case ISD::VECREDUCE_SMAX:
22952316 ScalarOps = {{ISD::SMAX, 2 }};
2296- IsReassociatable = true ;
2317+ IsAssociative = true ;
22972318 break ;
22982319 case ISD::VECREDUCE_SMIN:
22992320 ScalarOps = {{ISD::SMIN, 2 }};
2300- IsReassociatable = true ;
2321+ IsAssociative = true ;
23012322 break ;
23022323 case ISD::VECREDUCE_AND:
23032324 ScalarOps = {{ISD::AND, 2 }};
2304- IsReassociatable = true ;
2325+ IsAssociative = true ;
23052326 break ;
23062327 case ISD::VECREDUCE_OR:
23072328 ScalarOps = {{ISD::OR, 2 }};
2308- IsReassociatable = true ;
2329+ IsAssociative = true ;
23092330 break ;
23102331 case ISD::VECREDUCE_XOR:
23112332 ScalarOps = {{ISD::XOR, 2 }};
2312- IsReassociatable = true ;
2333+ IsAssociative = true ;
23132334 break ;
23142335 default :
23152336 llvm_unreachable (" unhandled vecreduce operation" );
@@ -2326,18 +2347,21 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
23262347 }
23272348
23282349 // Lower to tree reduction.
2329- if (IsReassociatable || Flags. hasAllowReassociation ( )) {
2330- // we don't expect an accumulator for reassociatable vector reduction ops
2350+ if (IsAssociative || allowUnsafeFPMath (DAG. getMachineFunction () )) {
2351+ // We don't expect an accumulator for reassociable vector reduction ops.
23312352 assert (!Accumulator && " unexpected accumulator" );
23322353 return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
23332354 }
23342355
23352356 // Lower to sequential reduction.
23362357 for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2358+ // Try to reduce the remaining sequence as much as possible using the
2359+ // current operator.
23372360 assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
23382361 const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
23392362
23402363 if (!Accumulator) {
2364+ // Try to initialize the accumulator using the current operator.
23412365 if (I + DefaultGroupSize <= NumElts) {
23422366 Accumulator = DAG.getNode (
23432367 DefaultScalarOp, DL, EltTy,
0 commit comments