@@ -835,12 +835,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
835835 setTargetDAGCombine (ISD::SETCC);
836836
837837 // Vector reduction operations. These are transformed into a tree evaluation
838- // of nodes which may or may not be legal .
838+ // of nodes which may initially be illegal .
839839 for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
840- setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
841- ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
842- ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
843- VT, Custom);
840+ MVT EltVT = VT.getVectorElementType ();
841+ if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
842+ EltVT == MVT::f64 ) {
843+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
844+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
845+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
846+ VT, Custom);
847+ } else if (EltVT.isScalarInteger ()) {
848+ setOperationAction (
849+ {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
850+ ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
851+ ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
852+ VT, Custom);
853+ }
844854 }
845855
846856 // Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2147,29 +2157,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21472157 return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
21482158}
21492159
2150- // / A generic routine for constructing a tree reduction for a vector operand.
2160+ // / A generic routine for constructing a tree reduction on a vector operand.
21512161// / This method differs from iterative splitting in DAGTypeLegalizer by
2152- // / first scalarizing the vector and then progressively grouping elements
2153- // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2154- // / with different numbers of operands (eg. max3 vs max2).
2162+ // / progressively grouping elements bottom-up.
21552163static SDValue BuildTreeReduction (
2156- const SDValue &VectorOp ,
2164+ const SmallVector< SDValue> &Elements, EVT EltTy ,
21572165 ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
21582166 const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2159- EVT VectorTy = VectorOp.getValueType ();
2160- EVT EltTy = VectorTy.getVectorElementType ();
2161- const unsigned NumElts = VectorTy.getVectorNumElements ();
2162-
2163- // scalarize vector
2164- SmallVector<SDValue> Elements (NumElts);
2165- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2166- Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2167- DAG.getConstant (I, DL, MVT::i64 ));
2168- }
2169-
21702167 // now build the computation graph in place at each level
21712168 SmallVector<SDValue> Level = Elements;
2172- for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2169+ unsigned OpIdx = 0 ;
2170+ while (Level.size () > 1 ) {
21732171 const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21742172
21752173 // partially reduce all elements in level
@@ -2201,52 +2199,139 @@ static SDValue BuildTreeReduction(
22012199 return *Level.begin ();
22022200}
22032201
2204- // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2205- // / serializes it.
2202+ // / Lower reductions to either a sequence of operations or a tree if
2203+ // / reassociations are allowed. This method will use larger operations like
2204+ // / max3/min3 when the target supports them.
22062205SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
22072206 SelectionDAG &DAG) const {
2208- // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2209- if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2207+ if (DisableFOpTreeReduce)
22102208 return SDValue ();
22112209
2212- EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2210+ SDLoc DL (Op);
2211+ const SDNodeFlags Flags = Op->getFlags ();
2212+ const SDValue &Vector = Op.getOperand (0 );
2213+ EVT EltTy = Vector.getValueType ().getVectorElementType ();
22132214 const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
22142215 STI.getPTXVersion () >= 88 ;
2215- SDLoc DL (Op);
2216- SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2216+
2217+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2218+ // number of inputs they take.
2219+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2220+ bool IsReassociatable;
2221+
22172222 switch (Op->getOpcode ()) {
22182223 case ISD::VECREDUCE_FADD:
2219- Operators = {{ISD::FADD, 2 }};
2224+ ScalarOps = {{ISD::FADD, 2 }};
2225+ IsReassociatable = false ;
22202226 break ;
22212227 case ISD::VECREDUCE_FMUL:
2222- Operators = {{ISD::FMUL, 2 }};
2228+ ScalarOps = {{ISD::FMUL, 2 }};
2229+ IsReassociatable = false ;
22232230 break ;
22242231 case ISD::VECREDUCE_FMAX:
22252232 if (CanUseMinMax3)
2226- Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2227- Operators.push_back ({ISD::FMAXNUM, 2 });
2233+ ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2234+ ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2235+ IsReassociatable = false ;
22282236 break ;
22292237 case ISD::VECREDUCE_FMIN:
22302238 if (CanUseMinMax3)
2231- Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2232- Operators.push_back ({ISD::FMINNUM, 2 });
2239+ ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2240+ ScalarOps.push_back ({ISD::FMINNUM, 2 });
2241+ IsReassociatable = false ;
22332242 break ;
22342243 case ISD::VECREDUCE_FMAXIMUM:
22352244 if (CanUseMinMax3)
2236- Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2237- Operators.push_back ({ISD::FMAXIMUM, 2 });
2245+ ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2246+ ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2247+ IsReassociatable = false ;
22382248 break ;
22392249 case ISD::VECREDUCE_FMINIMUM:
22402250 if (CanUseMinMax3)
2241- Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2242- Operators.push_back ({ISD::FMINIMUM, 2 });
2251+ ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2252+ ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2253+ IsReassociatable = false ;
2254+ break ;
2255+ case ISD::VECREDUCE_ADD:
2256+ ScalarOps = {{ISD::ADD, 2 }};
2257+ IsReassociatable = true ;
2258+ break ;
2259+ case ISD::VECREDUCE_MUL:
2260+ ScalarOps = {{ISD::MUL, 2 }};
2261+ IsReassociatable = true ;
2262+ break ;
2263+ case ISD::VECREDUCE_UMAX:
2264+ ScalarOps = {{ISD::UMAX, 2 }};
2265+ IsReassociatable = true ;
2266+ break ;
2267+ case ISD::VECREDUCE_UMIN:
2268+ ScalarOps = {{ISD::UMIN, 2 }};
2269+ IsReassociatable = true ;
2270+ break ;
2271+ case ISD::VECREDUCE_SMAX:
2272+ ScalarOps = {{ISD::SMAX, 2 }};
2273+ IsReassociatable = true ;
2274+ break ;
2275+ case ISD::VECREDUCE_SMIN:
2276+ ScalarOps = {{ISD::SMIN, 2 }};
2277+ IsReassociatable = true ;
2278+ break ;
2279+ case ISD::VECREDUCE_AND:
2280+ ScalarOps = {{ISD::AND, 2 }};
2281+ IsReassociatable = true ;
2282+ break ;
2283+ case ISD::VECREDUCE_OR:
2284+ ScalarOps = {{ISD::OR, 2 }};
2285+ IsReassociatable = true ;
2286+ break ;
2287+ case ISD::VECREDUCE_XOR:
2288+ ScalarOps = {{ISD::XOR, 2 }};
2289+ IsReassociatable = true ;
22432290 break ;
22442291 default :
22452292 llvm_unreachable (" unhandled vecreduce operation" );
22462293 }
22472294
2248- return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2249- DAG);
2295+ EVT VectorTy = Vector.getValueType ();
2296+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2297+
2298+ // scalarize vector
2299+ SmallVector<SDValue> Elements (NumElts);
2300+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2301+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2302+ DAG.getConstant (I, DL, MVT::i64 ));
2303+ }
2304+
2305+ // Lower to tree reduction.
2306+ if (IsReassociatable || Flags.hasAllowReassociation ())
2307+ return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2308+
2309+ // Lower to sequential reduction.
2310+ SDValue Accumulator;
2311+ for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2312+ assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2313+ const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2314+
2315+ if (!Accumulator) {
2316+ if (I + DefaultGroupSize <= NumElts) {
2317+ Accumulator = DAG.getNode (
2318+ DefaultScalarOp, DL, EltTy,
2319+ ArrayRef (Elements).slice (I, I + DefaultGroupSize), Flags);
2320+ I += DefaultGroupSize;
2321+ }
2322+ }
2323+
2324+ if (Accumulator) {
2325+ for (; I + (DefaultGroupSize - 1 ) <= NumElts; I += DefaultGroupSize - 1 ) {
2326+ SmallVector<SDValue> Operands = {Accumulator};
2327+ for (unsigned K = 0 ; K < DefaultGroupSize - 1 ; ++K)
2328+ Operands.push_back (Elements[I + K]);
2329+ Accumulator = DAG.getNode (DefaultScalarOp, DL, EltTy, Operands, Flags);
2330+ }
2331+ }
2332+ }
2333+
2334+ return Accumulator;
22502335}
22512336
22522337SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
@@ -3032,6 +3117,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30323117 case ISD::VECREDUCE_FMIN:
30333118 case ISD::VECREDUCE_FMAXIMUM:
30343119 case ISD::VECREDUCE_FMINIMUM:
3120+ case ISD::VECREDUCE_ADD:
3121+ case ISD::VECREDUCE_MUL:
3122+ case ISD::VECREDUCE_UMAX:
3123+ case ISD::VECREDUCE_UMIN:
3124+ case ISD::VECREDUCE_SMAX:
3125+ case ISD::VECREDUCE_SMIN:
3126+ case ISD::VECREDUCE_AND:
3127+ case ISD::VECREDUCE_OR:
3128+ case ISD::VECREDUCE_XOR:
30353129 return LowerVECREDUCE (Op, DAG);
30363130 case ISD::STORE:
30373131 return LowerSTORE (Op, DAG);
0 commit comments