@@ -85,6 +85,12 @@ static cl::opt<unsigned> FMAContractLevelOpt(
8585 " 1: do it 2: do it aggressively" ),
8686 cl::init(2 ));
8787
88+ static cl::opt<bool > DisableFOpTreeReduce (
89+ " nvptx-disable-fop-tree-reduce" , cl::Hidden,
90+ cl::desc (" NVPTX Specific: don't emit tree reduction for floating-point "
91+ " reduction operations" ),
92+ cl::init(false ));
93+
8894static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32 (
8995 " nvptx-prec-divf32" , cl::Hidden,
9096 cl::desc (" NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -863,6 +869,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
863869 if (STI.allowFP16Math () || STI.hasBF16Math ())
864870 setTargetDAGCombine (ISD::SETCC);
865871
872+ // Vector reduction operations. These are transformed into a tree evaluation
873+ // of nodes which may or may not be legal.
874+ for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
875+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
876+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
877+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
878+ VT, Custom);
879+ }
880+
866881 // Promote fp16 arithmetic if fp16 hardware isn't available or the
867882 // user passed --nvptx-no-fp16-math. The flag is useful because,
868883 // although sm_53+ GPUs have some sort of FP16 support in
@@ -1120,6 +1135,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
11201135 MAKE_CASE (NVPTXISD::BFI)
11211136 MAKE_CASE (NVPTXISD::PRMT)
11221137 MAKE_CASE (NVPTXISD::FCOPYSIGN)
1138+ MAKE_CASE (NVPTXISD::FMAXNUM3)
1139+ MAKE_CASE (NVPTXISD::FMINNUM3)
1140+ MAKE_CASE (NVPTXISD::FMAXIMUM3)
1141+ MAKE_CASE (NVPTXISD::FMINIMUM3)
11231142 MAKE_CASE (NVPTXISD::DYNAMIC_STACKALLOC)
11241143 MAKE_CASE (NVPTXISD::STACKRESTORE)
11251144 MAKE_CASE (NVPTXISD::STACKSAVE)
@@ -2194,6 +2213,108 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21942213 return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
21952214}
21962215
2216+ // / A generic routine for constructing a tree reduction for a vector operand.
2217+ // / This method differs from iterative splitting in DAGTypeLegalizer by
2218+ // / first scalarizing the vector and then progressively grouping elements
2219+ // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2220+ // / with different numbers of operands (eg. max3 vs max2).
2221+ static SDValue BuildTreeReduction (
2222+ const SDValue &VectorOp,
2223+ ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2224+ const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2225+ EVT VectorTy = VectorOp.getValueType ();
2226+ EVT EltTy = VectorTy.getVectorElementType ();
2227+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2228+
2229+ // scalarize vector
2230+ SmallVector<SDValue> Elements (NumElts);
2231+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2232+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2233+ DAG.getConstant (I, DL, MVT::i64 ));
2234+ }
2235+
2236+ // now build the computation graph in place at each level
2237+ SmallVector<SDValue> Level = Elements;
2238+ for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2239+ const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2240+
2241+ // partially reduce all elements in level
2242+ SmallVector<SDValue> ReducedLevel;
2243+ unsigned I = 0 , E = Level.size ();
2244+ for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
2245+ // Reduce elements in groups of [DefaultGroupSize], as much as possible.
2246+ ReducedLevel.push_back (DAG.getNode (
2247+ DefaultScalarOp, DL, EltTy,
2248+ ArrayRef<SDValue>(Level).slice (I, DefaultGroupSize), Flags));
2249+ }
2250+
2251+ if (I < E) {
2252+ if (ReducedLevel.empty ()) {
2253+ // The current operator requires more inputs than there are operands at
2254+ // this level. Pick a smaller operator and retry.
2255+ ++OpIdx;
2256+ assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
2257+ continue ;
2258+ }
2259+
2260+ // Otherwise, we just have a remainder, which we push to the next level.
2261+ for (; I < E; ++I)
2262+ ReducedLevel.push_back (Level[I]);
2263+ }
2264+ Level = ReducedLevel;
2265+ }
2266+
2267+ return *Level.begin ();
2268+ }
2269+
2270+ // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2271+ // / serializes it.
2272+ SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2273+ SelectionDAG &DAG) const {
2274+ // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2275+ if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2276+ return SDValue ();
2277+
2278+ EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2279+ const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2280+ STI.getPTXVersion () >= 88 ;
2281+ SDLoc DL (Op);
2282+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2283+ switch (Op->getOpcode ()) {
2284+ case ISD::VECREDUCE_FADD:
2285+ Operators = {{ISD::FADD, 2 }};
2286+ break ;
2287+ case ISD::VECREDUCE_FMUL:
2288+ Operators = {{ISD::FMUL, 2 }};
2289+ break ;
2290+ case ISD::VECREDUCE_FMAX:
2291+ if (CanUseMinMax3)
2292+ Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2293+ Operators.push_back ({ISD::FMAXNUM, 2 });
2294+ break ;
2295+ case ISD::VECREDUCE_FMIN:
2296+ if (CanUseMinMax3)
2297+ Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2298+ Operators.push_back ({ISD::FMINNUM, 2 });
2299+ break ;
2300+ case ISD::VECREDUCE_FMAXIMUM:
2301+ if (CanUseMinMax3)
2302+ Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2303+ Operators.push_back ({ISD::FMAXIMUM, 2 });
2304+ break ;
2305+ case ISD::VECREDUCE_FMINIMUM:
2306+ if (CanUseMinMax3)
2307+ Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2308+ Operators.push_back ({ISD::FMINIMUM, 2 });
2309+ break ;
2310+ default :
2311+ llvm_unreachable (" unhandled vecreduce operation" );
2312+ }
2313+
2314+ return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2315+ DAG);
2316+ }
2317+
21972318SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
21982319 // Handle bitcasting from v2i8 without hitting the default promotion
21992320 // strategy which goes through stack memory.
@@ -3026,6 +3147,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30263147 return LowerVECTOR_SHUFFLE (Op, DAG);
30273148 case ISD::CONCAT_VECTORS:
30283149 return LowerCONCAT_VECTORS (Op, DAG);
3150+ case ISD::VECREDUCE_FADD:
3151+ case ISD::VECREDUCE_FMUL:
3152+ case ISD::VECREDUCE_FMAX:
3153+ case ISD::VECREDUCE_FMIN:
3154+ case ISD::VECREDUCE_FMAXIMUM:
3155+ case ISD::VECREDUCE_FMINIMUM:
3156+ return LowerVECREDUCE (Op, DAG);
30293157 case ISD::STORE:
30303158 return LowerSTORE (Op, DAG);
30313159 case ISD::LOAD:
0 commit comments