@@ -2225,19 +2225,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
22252225}
22262226
22272227// / A generic routine for constructing a tree reduction on a vector operand.
2228- // / This method differs from iterative splitting in DAGTypeLegalizer by
2229- // / progressively grouping elements bottom-up.
2228+ // / This method groups elements bottom-up, progressively building each level.
2229+ // / This approach differs from the top-down iterative splitting used in
2230+ // / DAGTypeLegalizer and ExpandReductions.
2231+ // /
2232+ // / Also, the flags on the original reduction operation will be propagated to
2233+ // / each scalar operation.
22302234static SDValue BuildTreeReduction (
22312235 const SmallVector<SDValue> &Elements, EVT EltTy,
22322236 ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
22332237 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2234- // now build the computation graph in place at each level
2238+ // Build the reduction tree at each level, starting with all the elements.
22352239 SmallVector<SDValue> Level = Elements;
2240+
22362241 unsigned OpIdx = 0 ;
22372242 while (Level.size () > 1 ) {
2243+ // Try to reduce this level using the current operator.
22382244 const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
22392245
2240- // partially reduce all elements in level
2246+ // Build the next level by partially reducing all elements.
22412247 SmallVector<SDValue> ReducedLevel;
22422248 unsigned I = 0 , E = Level.size ();
22432249 for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
@@ -2248,18 +2254,23 @@ static SDValue BuildTreeReduction(
22482254 }
22492255
22502256 if (I < E) {
2257+ // We have leftover elements. Why?
2258+
22512259 if (ReducedLevel.empty ()) {
2252- // The current operator requires more inputs than there are operands at
2253- // this level . Pick a smaller operator and retry.
2260+ // ...because this level is now so small that the current operator is
2261+ // too big for it . Pick a smaller operator and retry.
22542262 ++OpIdx;
22552263 assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
22562264 continue ;
22572265 }
22582266
2259- // Otherwise, we just have a remainder, which we push to the next level.
2267+ // ...because the operator's required number of inputs doesn't evenly
2268+ // divide the size of this level. We push this remainder to the next level.
22602269 for (; I < E; ++I)
22612270 ReducedLevel.push_back (Level[I]);
22622271 }
2272+
2273+ // Process the next level.
22632274 Level = ReducedLevel;
22642275 }
22652276
@@ -2275,6 +2286,7 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22752286 const SDNodeFlags Flags = Op->getFlags ();
22762287 SDValue Vector;
22772288 SDValue Accumulator;
2289+
22782290 if (Op->getOpcode () == ISD::VECREDUCE_SEQ_FADD ||
22792291 Op->getOpcode () == ISD::VECREDUCE_SEQ_FMUL) {
22802292 // special case with accumulator as first arg
@@ -2284,85 +2296,94 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22842296 // default case
22852297 Vector = Op.getOperand (0 );
22862298 }
2299+
22872300 EVT EltTy = Vector.getValueType ().getVectorElementType ();
22882301 const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
22892302 STI.getPTXVersion () >= 88 ;
22902303
22912304 // A list of SDNode opcodes with equivalent semantics, sorted descending by
22922305 // number of inputs they take.
22932306 SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2294- bool IsReassociatable;
2307+
2308+ // Whether we can lower to scalar operations in an arbitrary order.
2309+ bool IsAssociative;
22952310
22962311 switch (Op->getOpcode ()) {
22972312 case ISD::VECREDUCE_FADD:
22982313 case ISD::VECREDUCE_SEQ_FADD:
22992314 ScalarOps = {{ISD::FADD, 2 }};
2300- IsReassociatable = false ;
2315+ IsAssociative = Op-> getOpcode () == ISD::VECREDUCE_FADD ;
23012316 break ;
23022317 case ISD::VECREDUCE_FMUL:
23032318 case ISD::VECREDUCE_SEQ_FMUL:
23042319 ScalarOps = {{ISD::FMUL, 2 }};
2305- IsReassociatable = false ;
2320+ IsAssociative = Op-> getOpcode () == ISD::VECREDUCE_FMUL ;
23062321 break ;
23072322 case ISD::VECREDUCE_FMAX:
23082323 if (CanUseMinMax3)
23092324 ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
23102325 ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2311- IsReassociatable = false ;
2326+ // The definition of maxNum in IEEE 754-2008 is non-associative, but only
2327+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2328+ // sNaNs.
2329+ IsAssociative = true ;
23122330 break ;
23132331 case ISD::VECREDUCE_FMIN:
23142332 if (CanUseMinMax3)
23152333 ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
23162334 ScalarOps.push_back ({ISD::FMINNUM, 2 });
2317- IsReassociatable = false ;
2335+ // The definition of minNum in IEEE 754-2008 is non-associative, but only
2336+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2337+ // sNaNs.
2338+ IsAssociative = true ;
23182339 break ;
23192340 case ISD::VECREDUCE_FMAXIMUM:
23202341 if (CanUseMinMax3)
23212342 ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
23222343 ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2323- IsReassociatable = false ;
2344+ IsAssociative = true ;
23242345 break ;
23252346 case ISD::VECREDUCE_FMINIMUM:
23262347 if (CanUseMinMax3)
23272348 ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
23282349 ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2329- IsReassociatable = false ;
2350+ IsAssociative = true ;
23302351 break ;
23312352 case ISD::VECREDUCE_ADD:
23322353 ScalarOps = {{ISD::ADD, 2 }};
2333- IsReassociatable = true ;
2354+ IsAssociative = true ;
23342355 break ;
23352356 case ISD::VECREDUCE_MUL:
23362357 ScalarOps = {{ISD::MUL, 2 }};
2337- IsReassociatable = true ;
2358+ IsAssociative = true ;
23382359 break ;
23392360 case ISD::VECREDUCE_UMAX:
23402361 ScalarOps = {{ISD::UMAX, 2 }};
2341- IsReassociatable = true ;
2362+ IsAssociative = true ;
23422363 break ;
23432364 case ISD::VECREDUCE_UMIN:
23442365 ScalarOps = {{ISD::UMIN, 2 }};
2345- IsReassociatable = true ;
2366+ IsAssociative = true ;
23462367 break ;
23472368 case ISD::VECREDUCE_SMAX:
23482369 ScalarOps = {{ISD::SMAX, 2 }};
2349- IsReassociatable = true ;
2370+ IsAssociative = true ;
23502371 break ;
23512372 case ISD::VECREDUCE_SMIN:
23522373 ScalarOps = {{ISD::SMIN, 2 }};
2353- IsReassociatable = true ;
2374+ IsAssociative = true ;
23542375 break ;
23552376 case ISD::VECREDUCE_AND:
23562377 ScalarOps = {{ISD::AND, 2 }};
2357- IsReassociatable = true ;
2378+ IsAssociative = true ;
23582379 break ;
23592380 case ISD::VECREDUCE_OR:
23602381 ScalarOps = {{ISD::OR, 2 }};
2361- IsReassociatable = true ;
2382+ IsAssociative = true ;
23622383 break ;
23632384 case ISD::VECREDUCE_XOR:
23642385 ScalarOps = {{ISD::XOR, 2 }};
2365- IsReassociatable = true ;
2386+ IsAssociative = true ;
23662387 break ;
23672388 default :
23682389 llvm_unreachable (" unhandled vecreduce operation" );
@@ -2379,18 +2400,21 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
23792400 }
23802401
23812402 // Lower to tree reduction.
2382- if (IsReassociatable || Flags. hasAllowReassociation ( )) {
2383- // we don't expect an accumulator for reassociatable vector reduction ops
2403+ if (IsAssociative || allowUnsafeFPMath (DAG. getMachineFunction () )) {
2404+ // we don't expect an accumulator for reassociable vector reduction ops
23842405 assert (!Accumulator && " unexpected accumulator" );
23852406 return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
23862407 }
23872408
23882409 // Lower to sequential reduction.
23892410 for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2411+ // Try to reduce the remaining sequence as much as possible using the
2412+ // current operator.
23902413 assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
23912414 const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
23922415
23932416 if (!Accumulator) {
2417+ // Try to initialize the accumulator using the current operator.
23942418 if (I + DefaultGroupSize <= NumElts) {
23952419 Accumulator = DAG.getNode (
23962420 DefaultScalarOp, DL, EltTy,
0 commit comments