@@ -850,6 +850,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
850
850
if (STI.allowFP16Math () || STI.hasBF16Math ())
851
851
setTargetDAGCombine (ISD::SETCC);
852
852
853
+ // Vector reduction operations. These may be turned into shuffle or tree
854
+ // reductions depending on what instructions are available for each type.
855
+ for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
856
+ MVT EltVT = VT.getVectorElementType ();
857
+ if (EltVT == MVT::f32 || EltVT == MVT::f64 ) {
858
+ setOperationAction ({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
859
+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
860
+ VT, Custom);
861
+ }
862
+ }
863
+
853
864
// Promote fp16 arithmetic if fp16 hardware isn't available or the
854
865
// user passed --nvptx-no-fp16-math. The flag is useful because,
855
866
// although sm_53+ GPUs have some sort of FP16 support in
@@ -1093,6 +1104,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
1093
1104
MAKE_CASE (NVPTXISD::BFI)
1094
1105
MAKE_CASE (NVPTXISD::PRMT)
1095
1106
MAKE_CASE (NVPTXISD::FCOPYSIGN)
1107
+ MAKE_CASE (NVPTXISD::FMAXNUM3)
1108
+ MAKE_CASE (NVPTXISD::FMINNUM3)
1109
+ MAKE_CASE (NVPTXISD::FMAXIMUM3)
1110
+ MAKE_CASE (NVPTXISD::FMINIMUM3)
1096
1111
MAKE_CASE (NVPTXISD::DYNAMIC_STACKALLOC)
1097
1112
MAKE_CASE (NVPTXISD::STACKRESTORE)
1098
1113
MAKE_CASE (NVPTXISD::STACKSAVE)
@@ -1900,6 +1915,125 @@ static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
1900
1915
return getPRMT (A, B, DAG.getConstant (Selector, DL, MVT::i32 ), DL, DAG, Mode);
1901
1916
}
1902
1917
1918
+ // / Reduces the elements using the scalar operations provided. The operations
1919
+ // / are sorted descending in number of inputs they take. The flags on the
1920
+ // / original reduction operation will be propagated to each scalar operation.
1921
+ // / Nearby elements are grouped in tree reduction, unlike the shuffle reduction
1922
+ // / used in ExpandReductions and SelectionDAG.
1923
+ static SDValue buildTreeReduction (
1924
+ const SmallVector<SDValue> &Elements, EVT EltTy,
1925
+ ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
1926
+ const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1927
+ // Build the reduction tree at each level, starting with all the elements.
1928
+ SmallVector<SDValue> Level = Elements;
1929
+
1930
+ unsigned OpIdx = 0 ;
1931
+ while (Level.size () > 1 ) {
1932
+ // Try to reduce this level using the current operator.
1933
+ const auto [Op, NumInputs] = Ops[OpIdx];
1934
+
1935
+ // Build the next level by partially reducing all elements.
1936
+ SmallVector<SDValue> ReducedLevel;
1937
+ unsigned I = 0 , E = Level.size ();
1938
+ for (; I + NumInputs <= E; I += NumInputs) {
1939
+ // Reduce elements in groups of [NumInputs], as much as possible.
1940
+ ReducedLevel.push_back (DAG.getNode (
1941
+ Op, DL, EltTy, ArrayRef<SDValue>(Level).slice (I, NumInputs), Flags));
1942
+ }
1943
+
1944
+ if (I < E) {
1945
+ // Handle leftover elements.
1946
+
1947
+ if (ReducedLevel.empty ()) {
1948
+ // We didn't reduce anything at this level. We need to pick a smaller
1949
+ // operator.
1950
+ ++OpIdx;
1951
+ assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
1952
+ continue ;
1953
+ }
1954
+
1955
+ // We reduced some things but there's still more left, meaning the
1956
+ // operator's number of inputs doesn't evenly divide this level size. Move
1957
+ // these elements to the next level.
1958
+ for (; I < E; ++I)
1959
+ ReducedLevel.push_back (Level[I]);
1960
+ }
1961
+
1962
+ // Process the next level.
1963
+ Level = ReducedLevel;
1964
+ }
1965
+
1966
+ return *Level.begin ();
1967
+ }
1968
+
1969
+ // Get scalar reduction opcode
1970
+ static ISD::NodeType getScalarOpcodeForReduction (unsigned ReductionOpcode) {
1971
+ switch (ReductionOpcode) {
1972
+ case ISD::VECREDUCE_FMAX:
1973
+ return ISD::FMAXNUM;
1974
+ case ISD::VECREDUCE_FMIN:
1975
+ return ISD::FMINNUM;
1976
+ case ISD::VECREDUCE_FMAXIMUM:
1977
+ return ISD::FMAXIMUM;
1978
+ case ISD::VECREDUCE_FMINIMUM:
1979
+ return ISD::FMINIMUM;
1980
+ default :
1981
+ llvm_unreachable (" unhandled reduction opcode" );
1982
+ }
1983
+ }
1984
+
1985
+ // / Get 3-input scalar reduction opcode
1986
+ static std::optional<NVPTXISD::NodeType>
1987
+ getScalar3OpcodeForReduction (unsigned ReductionOpcode) {
1988
+ switch (ReductionOpcode) {
1989
+ case ISD::VECREDUCE_FMAX:
1990
+ return NVPTXISD::FMAXNUM3;
1991
+ case ISD::VECREDUCE_FMIN:
1992
+ return NVPTXISD::FMINNUM3;
1993
+ case ISD::VECREDUCE_FMAXIMUM:
1994
+ return NVPTXISD::FMAXIMUM3;
1995
+ case ISD::VECREDUCE_FMINIMUM:
1996
+ return NVPTXISD::FMINIMUM3;
1997
+ default :
1998
+ return std::nullopt;
1999
+ }
2000
+ }
2001
+
2002
+ // / Lower reductions to either a sequence of operations or a tree if
2003
+ // / reassociations are allowed. This method will use larger operations like
2004
+ // / max3/min3 when the target supports them.
2005
+ SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2006
+ SelectionDAG &DAG) const {
2007
+ SDLoc DL (Op);
2008
+ const SDNodeFlags Flags = Op->getFlags ();
2009
+ SDValue Vector = Op.getOperand (0 );
2010
+
2011
+ const unsigned Opcode = Op->getOpcode ();
2012
+ const EVT EltTy = Vector.getValueType ().getVectorElementType ();
2013
+
2014
+ // Whether we can use 3-input min/max when expanding the reduction.
2015
+ const bool CanUseMinMax3 =
2016
+ EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2017
+ STI.getPTXVersion () >= 88 &&
2018
+ (Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
2019
+ Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
2020
+
2021
+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2022
+ // number of inputs they take.
2023
+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2024
+
2025
+ if (auto Opcode3Elem = getScalar3OpcodeForReduction (Opcode);
2026
+ CanUseMinMax3 && Opcode3Elem)
2027
+ ScalarOps.push_back ({*Opcode3Elem, 3 });
2028
+ ScalarOps.push_back ({getScalarOpcodeForReduction (Opcode), 2 });
2029
+
2030
+ // Otherwise, handle the reduction here.
2031
+ SmallVector<SDValue> Elements;
2032
+ DAG.ExtractVectorElements (Vector, Elements);
2033
+
2034
+ return buildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2035
+ }
2036
+
1903
2037
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
1904
2038
// Handle bitcasting from v2i8 without hitting the default promotion
1905
2039
// strategy which goes through stack memory.
@@ -2779,6 +2913,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2779
2913
return LowerVECTOR_SHUFFLE (Op, DAG);
2780
2914
case ISD::CONCAT_VECTORS:
2781
2915
return LowerCONCAT_VECTORS (Op, DAG);
2916
+ case ISD::VECREDUCE_FMAX:
2917
+ case ISD::VECREDUCE_FMIN:
2918
+ case ISD::VECREDUCE_FMAXIMUM:
2919
+ case ISD::VECREDUCE_FMINIMUM:
2920
+ return LowerVECREDUCE (Op, DAG);
2782
2921
case ISD::STORE:
2783
2922
return LowerSTORE (Op, DAG);
2784
2923
case ISD::LOAD:
0 commit comments