@@ -841,12 +841,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
841841 setTargetDAGCombine (ISD::SETCC);
842842
843843 // Vector reduction operations. These are transformed into a tree evaluation
844- // of nodes which may or may not be legal .
844+ // of nodes which may initially be illegal .
845845 for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
846- setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
847- ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
848- ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
849- VT, Custom);
846+ MVT EltVT = VT.getVectorElementType ();
847+ if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
848+ EltVT == MVT::f64 ) {
849+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
850+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
851+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
852+ VT, Custom);
853+ } else if (EltVT.isScalarInteger ()) {
854+ setOperationAction (
855+ {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
856+ ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
857+ ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
858+ VT, Custom);
859+ }
850860 }
851861
852862 // Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2166,29 +2176,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21662176 return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
21672177}
21682178
2169- // / A generic routine for constructing a tree reduction for a vector operand.
2179+ // / A generic routine for constructing a tree reduction on a vector operand.
21702180// / This method differs from iterative splitting in DAGTypeLegalizer by
2171- // / first scalarizing the vector and then progressively grouping elements
2172- // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2173- // / with different numbers of operands (eg. max3 vs max2).
2181+ // / progressively grouping elements bottom-up.
21742182static SDValue BuildTreeReduction (
2175- const SDValue &VectorOp ,
2183+ const SmallVector< SDValue> &Elements, EVT EltTy ,
21762184 ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
21772185 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2178- EVT VectorTy = VectorOp.getValueType ();
2179- EVT EltTy = VectorTy.getVectorElementType ();
2180- const unsigned NumElts = VectorTy.getVectorNumElements ();
2181-
2182- // scalarize vector
2183- SmallVector<SDValue> Elements (NumElts);
2184- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2185- Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2186- DAG.getConstant (I, DL, MVT::i64 ));
2187- }
2188-
21892186 // now build the computation graph in place at each level
21902187 SmallVector<SDValue> Level = Elements;
2191- for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2188+ unsigned OpIdx = 0 ;
2189+ while (Level.size () > 1 ) {
21922190 const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21932191
21942192 // partially reduce all elements in level
@@ -2220,52 +2218,139 @@ static SDValue BuildTreeReduction(
22202218 return *Level.begin ();
22212219}
22222220
2223- // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2224- // / serializes it.
2221+ // / Lower reductions to either a sequence of operations or a tree if
2222+ // / reassociations are allowed. This method will use larger operations like
2223+ // / max3/min3 when the target supports them.
22252224SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
22262225 SelectionDAG &DAG) const {
2227- // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2228- if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2226+ if (DisableFOpTreeReduce)
22292227 return SDValue ();
22302228
2231- EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2229+ SDLoc DL (Op);
2230+ const SDNodeFlags Flags = Op->getFlags ();
2231+ const SDValue &Vector = Op.getOperand (0 );
2232+ EVT EltTy = Vector.getValueType ().getVectorElementType ();
22322233 const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
22332234 STI.getPTXVersion () >= 88 ;
2234- SDLoc DL (Op);
2235- SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2235+
2236+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2237+ // number of inputs they take.
2238+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2239+ bool IsReassociatable;
2240+
22362241 switch (Op->getOpcode ()) {
22372242 case ISD::VECREDUCE_FADD:
2238- Operators = {{ISD::FADD, 2 }};
2243+ ScalarOps = {{ISD::FADD, 2 }};
2244+ IsReassociatable = false ;
22392245 break ;
22402246 case ISD::VECREDUCE_FMUL:
2241- Operators = {{ISD::FMUL, 2 }};
2247+ ScalarOps = {{ISD::FMUL, 2 }};
2248+ IsReassociatable = false ;
22422249 break ;
22432250 case ISD::VECREDUCE_FMAX:
22442251 if (CanUseMinMax3)
2245- Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2246- Operators.push_back ({ISD::FMAXNUM, 2 });
2252+ ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2253+ ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2254+ IsReassociatable = false ;
22472255 break ;
22482256 case ISD::VECREDUCE_FMIN:
22492257 if (CanUseMinMax3)
2250- Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2251- Operators.push_back ({ISD::FMINNUM, 2 });
2258+ ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2259+ ScalarOps.push_back ({ISD::FMINNUM, 2 });
2260+ IsReassociatable = false ;
22522261 break ;
22532262 case ISD::VECREDUCE_FMAXIMUM:
22542263 if (CanUseMinMax3)
2255- Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2256- Operators.push_back ({ISD::FMAXIMUM, 2 });
2264+ ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2265+ ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2266+ IsReassociatable = false ;
22572267 break ;
22582268 case ISD::VECREDUCE_FMINIMUM:
22592269 if (CanUseMinMax3)
2260- Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2261- Operators.push_back ({ISD::FMINIMUM, 2 });
2270+ ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2271+ ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2272+ IsReassociatable = false ;
2273+ break ;
2274+ case ISD::VECREDUCE_ADD:
2275+ ScalarOps = {{ISD::ADD, 2 }};
2276+ IsReassociatable = true ;
2277+ break ;
2278+ case ISD::VECREDUCE_MUL:
2279+ ScalarOps = {{ISD::MUL, 2 }};
2280+ IsReassociatable = true ;
2281+ break ;
2282+ case ISD::VECREDUCE_UMAX:
2283+ ScalarOps = {{ISD::UMAX, 2 }};
2284+ IsReassociatable = true ;
2285+ break ;
2286+ case ISD::VECREDUCE_UMIN:
2287+ ScalarOps = {{ISD::UMIN, 2 }};
2288+ IsReassociatable = true ;
2289+ break ;
2290+ case ISD::VECREDUCE_SMAX:
2291+ ScalarOps = {{ISD::SMAX, 2 }};
2292+ IsReassociatable = true ;
2293+ break ;
2294+ case ISD::VECREDUCE_SMIN:
2295+ ScalarOps = {{ISD::SMIN, 2 }};
2296+ IsReassociatable = true ;
2297+ break ;
2298+ case ISD::VECREDUCE_AND:
2299+ ScalarOps = {{ISD::AND, 2 }};
2300+ IsReassociatable = true ;
2301+ break ;
2302+ case ISD::VECREDUCE_OR:
2303+ ScalarOps = {{ISD::OR, 2 }};
2304+ IsReassociatable = true ;
2305+ break ;
2306+ case ISD::VECREDUCE_XOR:
2307+ ScalarOps = {{ISD::XOR, 2 }};
2308+ IsReassociatable = true ;
22622309 break ;
22632310 default :
22642311 llvm_unreachable (" unhandled vecreduce operation" );
22652312 }
22662313
2267- return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2268- DAG);
2314+ EVT VectorTy = Vector.getValueType ();
2315+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2316+
2317+ // scalarize vector
2318+ SmallVector<SDValue> Elements (NumElts);
2319+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2320+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2321+ DAG.getConstant (I, DL, MVT::i64 ));
2322+ }
2323+
2324+ // Lower to tree reduction.
2325+ if (IsReassociatable || Flags.hasAllowReassociation ())
2326+ return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2327+
2328+ // Lower to sequential reduction.
2329+ SDValue Accumulator;
2330+ for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2331+ assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2332+ const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2333+
2334+ if (!Accumulator) {
2335+ if (I + DefaultGroupSize <= NumElts) {
2336+ Accumulator = DAG.getNode (
2337+ DefaultScalarOp, DL, EltTy,
2338+ ArrayRef (Elements).slice (I, I + DefaultGroupSize), Flags);
2339+ I += DefaultGroupSize;
2340+ }
2341+ }
2342+
2343+ if (Accumulator) {
2344+ for (; I + (DefaultGroupSize - 1 ) <= NumElts; I += DefaultGroupSize - 1 ) {
2345+ SmallVector<SDValue> Operands = {Accumulator};
2346+ for (unsigned K = 0 ; K < DefaultGroupSize - 1 ; ++K)
2347+ Operands.push_back (Elements[I + K]);
2348+ Accumulator = DAG.getNode (DefaultScalarOp, DL, EltTy, Operands, Flags);
2349+ }
2350+ }
2351+ }
2352+
2353+ return Accumulator;
22692354}
22702355
22712356SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
@@ -3062,6 +3147,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30623147 case ISD::VECREDUCE_FMIN:
30633148 case ISD::VECREDUCE_FMAXIMUM:
30643149 case ISD::VECREDUCE_FMINIMUM:
3150+ case ISD::VECREDUCE_ADD:
3151+ case ISD::VECREDUCE_MUL:
3152+ case ISD::VECREDUCE_UMAX:
3153+ case ISD::VECREDUCE_UMIN:
3154+ case ISD::VECREDUCE_SMAX:
3155+ case ISD::VECREDUCE_SMIN:
3156+ case ISD::VECREDUCE_AND:
3157+ case ISD::VECREDUCE_OR:
3158+ case ISD::VECREDUCE_XOR:
30653159 return LowerVECREDUCE (Op, DAG);
30663160 case ISD::STORE:
30673161 return LowerSTORE (Op, DAG);
0 commit comments