@@ -870,12 +870,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
870870 setTargetDAGCombine (ISD::SETCC);
871871
872872 // Vector reduction operations. These are transformed into a tree evaluation
873- // of nodes which may or may not be legal .
873+ // of nodes which may initially be illegal .
874874 for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
875- setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
876- ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
877- ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
878- VT, Custom);
875+ MVT EltVT = VT.getVectorElementType ();
876+ if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
877+ EltVT == MVT::f64 ) {
878+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
879+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
880+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
881+ VT, Custom);
882+ } else if (EltVT.isScalarInteger ()) {
883+ setOperationAction (
884+ {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
885+ ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
886+ ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
887+ VT, Custom);
888+ }
879889 }
880890
881891 // Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2213,29 +2223,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
22132223 return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
22142224}
22152225
2216- // / A generic routine for constructing a tree reduction for a vector operand.
2226+ // / A generic routine for constructing a tree reduction on a vector operand.
22172227// / This method differs from iterative splitting in DAGTypeLegalizer by
2218- // / first scalarizing the vector and then progressively grouping elements
2219- // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2220- // / with different numbers of operands (eg. max3 vs max2).
2228+ // / progressively grouping elements bottom-up.
22212229static SDValue BuildTreeReduction (
2222- const SDValue &VectorOp ,
2230+ const SmallVector< SDValue> &Elements, EVT EltTy ,
22232231 ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
22242232 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2225- EVT VectorTy = VectorOp.getValueType ();
2226- EVT EltTy = VectorTy.getVectorElementType ();
2227- const unsigned NumElts = VectorTy.getVectorNumElements ();
2228-
2229- // scalarize vector
2230- SmallVector<SDValue> Elements (NumElts);
2231- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2232- Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2233- DAG.getConstant (I, DL, MVT::i64 ));
2234- }
2235-
22362233 // now build the computation graph in place at each level
22372234 SmallVector<SDValue> Level = Elements;
2238- for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2235+ unsigned OpIdx = 0 ;
2236+ while (Level.size () > 1 ) {
22392237 const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
22402238
22412239 // partially reduce all elements in level
@@ -2267,52 +2265,139 @@ static SDValue BuildTreeReduction(
22672265 return *Level.begin ();
22682266}
22692267
2270- // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2271- // / serializes it.
2268+ // / Lower reductions to either a sequence of operations or a tree if
2269+ // / reassociations are allowed. This method will use larger operations like
2270+ // / max3/min3 when the target supports them.
22722271SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
22732272 SelectionDAG &DAG) const {
2274- // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2275- if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2273+ if (DisableFOpTreeReduce)
22762274 return SDValue ();
22772275
2278- EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2276+ SDLoc DL (Op);
2277+ const SDNodeFlags Flags = Op->getFlags ();
2278+ const SDValue &Vector = Op.getOperand (0 );
2279+ EVT EltTy = Vector.getValueType ().getVectorElementType ();
22792280 const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
22802281 STI.getPTXVersion () >= 88 ;
2281- SDLoc DL (Op);
2282- SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2282+
2283+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2284+ // number of inputs they take.
2285+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2286+ bool IsReassociatable;
2287+
22832288 switch (Op->getOpcode ()) {
22842289 case ISD::VECREDUCE_FADD:
2285- Operators = {{ISD::FADD, 2 }};
2290+ ScalarOps = {{ISD::FADD, 2 }};
2291+ IsReassociatable = false ;
22862292 break ;
22872293 case ISD::VECREDUCE_FMUL:
2288- Operators = {{ISD::FMUL, 2 }};
2294+ ScalarOps = {{ISD::FMUL, 2 }};
2295+ IsReassociatable = false ;
22892296 break ;
22902297 case ISD::VECREDUCE_FMAX:
22912298 if (CanUseMinMax3)
2292- Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2293- Operators.push_back ({ISD::FMAXNUM, 2 });
2299+ ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2300+ ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2301+ IsReassociatable = false ;
22942302 break ;
22952303 case ISD::VECREDUCE_FMIN:
22962304 if (CanUseMinMax3)
2297- Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2298- Operators.push_back ({ISD::FMINNUM, 2 });
2305+ ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2306+ ScalarOps.push_back ({ISD::FMINNUM, 2 });
2307+ IsReassociatable = false ;
22992308 break ;
23002309 case ISD::VECREDUCE_FMAXIMUM:
23012310 if (CanUseMinMax3)
2302- Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2303- Operators.push_back ({ISD::FMAXIMUM, 2 });
2311+ ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2312+ ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2313+ IsReassociatable = false ;
23042314 break ;
23052315 case ISD::VECREDUCE_FMINIMUM:
23062316 if (CanUseMinMax3)
2307- Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2308- Operators.push_back ({ISD::FMINIMUM, 2 });
2317+ ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2318+ ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2319+ IsReassociatable = false ;
2320+ break ;
2321+ case ISD::VECREDUCE_ADD:
2322+ ScalarOps = {{ISD::ADD, 2 }};
2323+ IsReassociatable = true ;
2324+ break ;
2325+ case ISD::VECREDUCE_MUL:
2326+ ScalarOps = {{ISD::MUL, 2 }};
2327+ IsReassociatable = true ;
2328+ break ;
2329+ case ISD::VECREDUCE_UMAX:
2330+ ScalarOps = {{ISD::UMAX, 2 }};
2331+ IsReassociatable = true ;
2332+ break ;
2333+ case ISD::VECREDUCE_UMIN:
2334+ ScalarOps = {{ISD::UMIN, 2 }};
2335+ IsReassociatable = true ;
2336+ break ;
2337+ case ISD::VECREDUCE_SMAX:
2338+ ScalarOps = {{ISD::SMAX, 2 }};
2339+ IsReassociatable = true ;
2340+ break ;
2341+ case ISD::VECREDUCE_SMIN:
2342+ ScalarOps = {{ISD::SMIN, 2 }};
2343+ IsReassociatable = true ;
2344+ break ;
2345+ case ISD::VECREDUCE_AND:
2346+ ScalarOps = {{ISD::AND, 2 }};
2347+ IsReassociatable = true ;
2348+ break ;
2349+ case ISD::VECREDUCE_OR:
2350+ ScalarOps = {{ISD::OR, 2 }};
2351+ IsReassociatable = true ;
2352+ break ;
2353+ case ISD::VECREDUCE_XOR:
2354+ ScalarOps = {{ISD::XOR, 2 }};
2355+ IsReassociatable = true ;
23092356 break ;
23102357 default :
23112358 llvm_unreachable (" unhandled vecreduce operation" );
23122359 }
23132360
2314- return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2315- DAG);
2361+ EVT VectorTy = Vector.getValueType ();
2362+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2363+
2364+ // scalarize vector
2365+ SmallVector<SDValue> Elements (NumElts);
2366+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2367+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2368+ DAG.getConstant (I, DL, MVT::i64 ));
2369+ }
2370+
2371+ // Lower to tree reduction.
2372+ if (IsReassociatable || Flags.hasAllowReassociation ())
2373+ return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2374+
2375+ // Lower to sequential reduction.
2376+ SDValue Accumulator;
2377+ for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2378+ assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2379+ const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2380+
2381+ if (!Accumulator) {
2382+ if (I + DefaultGroupSize <= NumElts) {
2383+ Accumulator = DAG.getNode (
2384+ DefaultScalarOp, DL, EltTy,
2385+ ArrayRef (Elements).slice (I, I + DefaultGroupSize), Flags);
2386+ I += DefaultGroupSize;
2387+ }
2388+ }
2389+
2390+ if (Accumulator) {
2391+ for (; I + (DefaultGroupSize - 1 ) <= NumElts; I += DefaultGroupSize - 1 ) {
2392+ SmallVector<SDValue> Operands = {Accumulator};
2393+ for (unsigned K = 0 ; K < DefaultGroupSize - 1 ; ++K)
2394+ Operands.push_back (Elements[I + K]);
2395+ Accumulator = DAG.getNode (DefaultScalarOp, DL, EltTy, Operands, Flags);
2396+ }
2397+ }
2398+ }
2399+
2400+ return Accumulator;
23162401}
23172402
23182403SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
@@ -3153,6 +3238,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
31533238 case ISD::VECREDUCE_FMIN:
31543239 case ISD::VECREDUCE_FMAXIMUM:
31553240 case ISD::VECREDUCE_FMINIMUM:
3241+ case ISD::VECREDUCE_ADD:
3242+ case ISD::VECREDUCE_MUL:
3243+ case ISD::VECREDUCE_UMAX:
3244+ case ISD::VECREDUCE_UMIN:
3245+ case ISD::VECREDUCE_SMAX:
3246+ case ISD::VECREDUCE_SMIN:
3247+ case ISD::VECREDUCE_AND:
3248+ case ISD::VECREDUCE_OR:
3249+ case ISD::VECREDUCE_XOR:
31563250 return LowerVECREDUCE (Op, DAG);
31573251 case ISD::STORE:
31583252 return LowerSTORE (Op, DAG);
0 commit comments