@@ -900,6 +900,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
900900 if (STI.allowFP16Math () || STI.hasBF16Math ())
901901 setTargetDAGCombine (ISD::SETCC);
902902
903+ // Vector reduction operations. These may be turned into shuffle or tree
904+ // reductions depending on what instructions are available for each type.
905+ for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
906+ MVT EltVT = VT.getVectorElementType ();
907+ if (EltVT == MVT::f32 || EltVT == MVT::f64 ) {
908+ setOperationAction ({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
909+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
910+ VT, Custom);
911+ }
912+ }
913+
903914 // Promote fp16 arithmetic if fp16 hardware isn't available or the
904915 // user passed --nvptx-no-fp16-math. The flag is useful because,
905916 // although sm_53+ GPUs have some sort of FP16 support in
@@ -1143,6 +1154,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
11431154 MAKE_CASE (NVPTXISD::BFI)
11441155 MAKE_CASE (NVPTXISD::PRMT)
11451156 MAKE_CASE (NVPTXISD::FCOPYSIGN)
1157+ MAKE_CASE (NVPTXISD::FMAXNUM3)
1158+ MAKE_CASE (NVPTXISD::FMINNUM3)
1159+ MAKE_CASE (NVPTXISD::FMAXIMUM3)
1160+ MAKE_CASE (NVPTXISD::FMINIMUM3)
11461161 MAKE_CASE (NVPTXISD::DYNAMIC_STACKALLOC)
11471162 MAKE_CASE (NVPTXISD::STACKRESTORE)
11481163 MAKE_CASE (NVPTXISD::STACKSAVE)
@@ -1929,6 +1944,124 @@ static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
19291944 return getPRMT (A, B, DAG.getConstant (Selector, DL, MVT::i32 ), DL, DAG, Mode);
19301945}
19311946
1947+ // / Reduces the elements using the scalar operations provided. The operations
1948+ // / are sorted descending in number of inputs they take. The flags on the
1949+ // / original reduction operation will be propagated to each scalar operation.
1950+ // / Nearby elements are grouped in tree reduction, unlike the shuffle reduction
1951+ // / used in ExpandReductions and SelectionDAG.
1952+ static SDValue buildTreeReduction (
1953+ const SmallVector<SDValue> &Elements, EVT EltTy,
1954+ ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
1955+ const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1956+ // Build the reduction tree at each level, starting with all the elements.
1957+ SmallVector<SDValue> Level = Elements;
1958+
1959+ unsigned OpIdx = 0 ;
1960+ while (Level.size () > 1 ) {
1961+ // Try to reduce this level using the current operator.
1962+ const auto [Op, NumInputs] = Ops[OpIdx];
1963+
1964+ // Build the next level by partially reducing all elements.
1965+ SmallVector<SDValue> ReducedLevel;
1966+ unsigned I = 0 , E = Level.size ();
1967+ for (; I + NumInputs <= E; I += NumInputs) {
1968+ // Reduce elements in groups of [NumInputs], as much as possible.
1969+ ReducedLevel.push_back (DAG.getNode (
1970+ Op, DL, EltTy, ArrayRef<SDValue>(Level).slice (I, NumInputs), Flags));
1971+ }
1972+
1973+ if (I < E) {
1974+ // Handle leftover elements.
1975+
1976+ if (ReducedLevel.empty ()) {
1977+ // We didn't reduce anything at this level. We need to pick a smaller
1978+ // operator.
1979+ ++OpIdx;
1980+ assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
1981+ continue ;
1982+ }
1983+
1984+ // We reduced some things but there's still more left, meaning the
1985+ // operator's number of inputs doesn't evenly divide this level size. Move
1986+ // these elements to the next level.
1987+ for (; I < E; ++I)
1988+ ReducedLevel.push_back (Level[I]);
1989+ }
1990+
1991+ // Process the next level.
1992+ Level = ReducedLevel;
1993+ }
1994+
1995+ return *Level.begin ();
1996+ }
1997+
1998+ // Get scalar reduction opcode
1999+ static ISD::NodeType getScalarOpcodeForReduction (unsigned ReductionOpcode) {
2000+ switch (ReductionOpcode) {
2001+ case ISD::VECREDUCE_FMAX:
2002+ return ISD::FMAXNUM;
2003+ case ISD::VECREDUCE_FMIN:
2004+ return ISD::FMINNUM;
2005+ case ISD::VECREDUCE_FMAXIMUM:
2006+ return ISD::FMAXIMUM;
2007+ case ISD::VECREDUCE_FMINIMUM:
2008+ return ISD::FMINIMUM;
2009+ default :
2010+ llvm_unreachable (" unhandled reduction opcode" );
2011+ }
2012+ }
2013+
2014+ // / Get 3-input scalar reduction opcode
2015+ static std::optional<NVPTXISD::NodeType>
2016+ getScalar3OpcodeForReduction (unsigned ReductionOpcode) {
2017+ switch (ReductionOpcode) {
2018+ case ISD::VECREDUCE_FMAX:
2019+ return NVPTXISD::FMAXNUM3;
2020+ case ISD::VECREDUCE_FMIN:
2021+ return NVPTXISD::FMINNUM3;
2022+ case ISD::VECREDUCE_FMAXIMUM:
2023+ return NVPTXISD::FMAXIMUM3;
2024+ case ISD::VECREDUCE_FMINIMUM:
2025+ return NVPTXISD::FMINIMUM3;
2026+ default :
2027+ return std::nullopt ;
2028+ }
2029+ }
2030+
2031+ // / Lower reductions to either a sequence of operations or a tree if
2032+ // / reassociations are allowed. This method will use larger operations like
2033+ // / max3/min3 when the target supports them.
2034+ SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2035+ SelectionDAG &DAG) const {
2036+ SDLoc DL (Op);
2037+ const SDNodeFlags Flags = Op->getFlags ();
2038+ SDValue Vector = Op.getOperand (0 );
2039+
2040+ const unsigned Opcode = Op->getOpcode ();
2041+ const EVT EltTy = Vector.getValueType ().getVectorElementType ();
2042+
2043+ // Whether we can use 3-input min/max when expanding the reduction.
2044+ const bool CanUseMinMax3 =
2045+ EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2046+ STI.getPTXVersion () >= 88 &&
2047+ (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
2048+ Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
2049+
2050+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2051+ // number of inputs they take.
2052+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2053+
2054+ if (auto Opcode3Elem = getScalar3OpcodeForReduction (Opcode);
2055+ CanUseMinMax3 && Opcode3Elem)
2056+ ScalarOps.push_back ({*Opcode3Elem, 3 });
2057+ ScalarOps.push_back ({getScalarOpcodeForReduction (Opcode), 2 });
2058+
2059+ SmallVector<SDValue> Elements;
2060+ DAG.ExtractVectorElements (Vector, Elements);
2061+
2062+ return buildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2063+ }
2064+
19322065SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
19332066 // Handle bitcasting from v2i8 without hitting the default promotion
19342067 // strategy which goes through stack memory.
@@ -2808,6 +2941,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28082941 return LowerVECTOR_SHUFFLE (Op, DAG);
28092942 case ISD::CONCAT_VECTORS:
28102943 return LowerCONCAT_VECTORS (Op, DAG);
2944+ case ISD::VECREDUCE_FMAX:
2945+ case ISD::VECREDUCE_FMIN:
2946+ case ISD::VECREDUCE_FMAXIMUM:
2947+ case ISD::VECREDUCE_FMINIMUM:
2948+ return LowerVECREDUCE (Op, DAG);
28112949 case ISD::STORE:
28122950 return LowerSTORE (Op, DAG);
28132951 case ISD::LOAD:
0 commit comments