@@ -900,6 +900,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These may be turned into shuffle or tree
+  // reductions depending on what instructions are available for each type.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f32 || EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    }
+  }
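+  // Marking these Custom routes them through LowerOperation to
+  // LowerVECREDUCE (see below), which expands them into trees of scalar
+  // min/max (and 3-input min3/max3 where available).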
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1143,6 +1154,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -1929,6 +1944,124 @@ static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
   return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
 }
 
+/// Reduces the elements using the scalar operations provided. The operations
+/// are sorted descending in number of inputs they take. The flags on the
+/// original reduction operation will be propagated to each scalar operation.
+/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
+/// used in ExpandReductions and SelectionDAG.
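+///
+/// For example, with ops {Op3 (3 inputs), Op2 (2 inputs)} and 8 elements,
+/// this builds Op2(Op3(Op3(e0, e1, e2), Op3(e3, e4, e5), e6), e7).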
+static SDValue buildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [Op, NumInputs] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + NumInputs <= E; I += NumInputs) {
+      // Reduce elements in groups of [NumInputs], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          Op, DL, EltTy, ArrayRef<SDValue>(Level).slice(I, NumInputs), Flags));
+    }
+
+    if (I < E) {
+      // Handle leftover elements.
+
+      if (ReducedLevel.empty()) {
+        // We didn't reduce anything at this level. We need to pick a smaller
+        // operator.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // We reduced some things but there's still more left, meaning the
+      // operator's number of inputs doesn't evenly divide this level size. Move
+      // these elements to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+// Get scalar reduction opcode
+static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
+  switch (ReductionOpcode) {
+  case ISD::VECREDUCE_FMAX:
+    return ISD::FMAXNUM;
+  case ISD::VECREDUCE_FMIN:
+    return ISD::FMINNUM;
+  case ISD::VECREDUCE_FMAXIMUM:
+    return ISD::FMAXIMUM;
+  case ISD::VECREDUCE_FMINIMUM:
+    return ISD::FMINIMUM;
+  default:
+    llvm_unreachable("unhandled reduction opcode");
+  }
+}
+
+/// Get 3-input scalar reduction opcode
+static std::optional<NVPTXISD::NodeType>
+getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
+  switch (ReductionOpcode) {
+  case ISD::VECREDUCE_FMAX:
+    return NVPTXISD::FMAXNUM3;
+  case ISD::VECREDUCE_FMIN:
+    return NVPTXISD::FMINNUM3;
+  case ISD::VECREDUCE_FMAXIMUM:
+    return NVPTXISD::FMAXIMUM3;
+  case ISD::VECREDUCE_FMINIMUM:
+    return NVPTXISD::FMINIMUM3;
+  default:
+    return std::nullopt;
+  }
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
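+/// For example, a 4 x f32 fmax reduction on a target with max3 becomes
+/// FMAXNUM(FMAXNUM3(e0, e1, e2), e3).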
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector = Op.getOperand(0);
+
+  const unsigned Opcode = Op->getOpcode();
+  const EVT EltTy = Vector.getValueType().getVectorElementType();
+
+  // Whether we can use 3-input min/max when expanding the reduction.
+  const bool CanUseMinMax3 =
+      EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+      STI.getPTXVersion() >= 88 &&
+      (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
+       Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  if (auto Opcode3Elem = getScalar3OpcodeForReduction(Opcode);
+      CanUseMinMax3 && Opcode3Elem)
+    ScalarOps.push_back({*Opcode3Elem, 3});
+  ScalarOps.push_back({getScalarOpcodeForReduction(Opcode), 2});
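+  // For example, for an FMAX reduction on a target with max3, ScalarOps is
+  // now {{NVPTXISD::FMAXNUM3, 3}, {ISD::FMAXNUM, 2}}; otherwise just
+  // {{ISD::FMAXNUM, 2}}.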
+
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  return buildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2808,6 +2941,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD: