Skip to content

Commit f4f7fe1

Browse files
committed
[NVPTX] lower VECREDUCE max/min to 3-input on sm_100+
Add support for the 3-input fmaxnum/fminnum/fmaximum/fminimum operations introduced in PTX 8.8 for sm_100+. Other improvements:
- Fix the lowering of fmaxnum/fminnum so that they are not reassociated. According to the IEEE 754 definition of these operations (maxNum/minNum), they are not associative because of how they handle sNaNs. A quick example with a = 1.0, b = 1.0, c = sNaN:
    maxNum(a, maxNum(b, c)) = maxNum(a, qNaN) = a = 1.0
    maxNum(maxNum(a, b), c) = maxNum(1.0, sNaN) = qNaN
- Use a tree reduction when 3-input operations are supported and the reduction has the `reassoc` flag.
- If not on sm_100+/PTX 8.8, fall back to 2-input operations and use the default shuffle reduction.
1 parent 255bba0 commit f4f7fe1

File tree

5 files changed

+1222
-641
lines changed

5 files changed

+1222
-641
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
850850
if (STI.allowFP16Math() || STI.hasBF16Math())
851851
setTargetDAGCombine(ISD::SETCC);
852852

853+
// Vector reduction operations. These may be turned into sequential, shuffle,
854+
// or tree reductions depending on what instructions are available for each
855+
// type.
856+
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
857+
MVT EltVT = VT.getVectorElementType();
858+
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
859+
EltVT == MVT::f64) {
860+
setOperationAction({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
861+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
862+
VT, Custom);
863+
}
864+
}
865+
853866
// Promote fp16 arithmetic if fp16 hardware isn't available or the
854867
// user passed --nvptx-no-fp16-math. The flag is useful because,
855868
// although sm_53+ GPUs have some sort of FP16 support in
@@ -1093,6 +1106,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10931106
MAKE_CASE(NVPTXISD::BFI)
10941107
MAKE_CASE(NVPTXISD::PRMT)
10951108
MAKE_CASE(NVPTXISD::FCOPYSIGN)
1109+
MAKE_CASE(NVPTXISD::FMAXNUM3)
1110+
MAKE_CASE(NVPTXISD::FMINNUM3)
1111+
MAKE_CASE(NVPTXISD::FMAXIMUM3)
1112+
MAKE_CASE(NVPTXISD::FMINIMUM3)
10961113
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
10971114
MAKE_CASE(NVPTXISD::STACKRESTORE)
10981115
MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -1900,6 +1917,191 @@ static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
19001917
return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
19011918
}
19021919

1920+
/// A generic routine for constructing a tree reduction on a vector operand.
1921+
/// This method groups elements bottom-up, progressively building each level.
1922+
/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
1923+
/// adjacent elements are combined first, leading to shorter live ranges. This
1924+
/// approach makes the most sense if the shuffle reduction would use the same
1925+
/// amount of registers.
1926+
///
1927+
/// The flags on the original reduction operation will be propagated to
1928+
/// each scalar operation.
1929+
static SDValue BuildTreeReduction(
1930+
const SmallVector<SDValue> &Elements, EVT EltTy,
1931+
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
1932+
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1933+
// Build the reduction tree at each level, starting with all the elements.
1934+
SmallVector<SDValue> Level = Elements;
1935+
1936+
unsigned OpIdx = 0;
1937+
while (Level.size() > 1) {
1938+
// Try to reduce this level using the current operator.
1939+
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
1940+
1941+
// Build the next level by partially reducing all elements.
1942+
SmallVector<SDValue> ReducedLevel;
1943+
unsigned I = 0, E = Level.size();
1944+
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
1945+
// Reduce elements in groups of [DefaultGroupSize], as much as possible.
1946+
ReducedLevel.push_back(DAG.getNode(
1947+
DefaultScalarOp, DL, EltTy,
1948+
ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
1949+
}
1950+
1951+
if (I < E) {
1952+
// Handle leftover elements.
1953+
1954+
if (ReducedLevel.empty()) {
1955+
// We didn't reduce anything at this level. We need to pick a smaller
1956+
// operator.
1957+
++OpIdx;
1958+
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
1959+
continue;
1960+
}
1961+
1962+
// We reduced some things but there's still more left, meaning the
1963+
// operator's number of inputs doesn't evenly divide this level size. Move
1964+
// these elements to the next level.
1965+
for (; I < E; ++I)
1966+
ReducedLevel.push_back(Level[I]);
1967+
}
1968+
1969+
// Process the next level.
1970+
Level = ReducedLevel;
1971+
}
1972+
1973+
return *Level.begin();
1974+
}
1975+
1976+
/// Lower reductions to either a sequence of operations or a tree if
1977+
/// reassociations are allowed. This method will use larger operations like
1978+
/// max3/min3 when the target supports them.
1979+
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
1980+
SelectionDAG &DAG) const {
1981+
SDLoc DL(Op);
1982+
const SDNodeFlags Flags = Op->getFlags();
1983+
SDValue Vector = Op.getOperand(0);
1984+
SDValue Accumulator;
1985+
1986+
EVT EltTy = Vector.getValueType().getVectorElementType();
1987+
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
1988+
STI.getPTXVersion() >= 88;
1989+
1990+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
1991+
// number of inputs they take.
1992+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
1993+
1994+
// Whether we can lower to scalar operations in an arbitrary order.
1995+
bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
1996+
1997+
// Whether the data type and operation can be represented with fewer ops and
1998+
// registers in a shuffle reduction.
1999+
bool PrefersShuffle;
2000+
2001+
switch (Op->getOpcode()) {
2002+
case ISD::VECREDUCE_FMAX:
2003+
if (CanUseMinMax3) {
2004+
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
2005+
// Can't use fmaxnum3 in shuffle reduction
2006+
PrefersShuffle = false;
2007+
} else {
2008+
// Prefer max.{,b}f16x2 for v2{,b}f16
2009+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2010+
}
2011+
ScalarOps.push_back({ISD::FMAXNUM, 2});
2012+
// Definition of maxNum in IEEE 754 2008 is non-associative due to handling
2013+
// of sNaN inputs.
2014+
IsAssociative = Flags.hasNoNaNs();
2015+
break;
2016+
case ISD::VECREDUCE_FMIN:
2017+
if (CanUseMinMax3) {
2018+
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
2019+
// Can't use fminnum3 in shuffle reduction
2020+
PrefersShuffle = false;
2021+
} else {
2022+
// Prefer min.{,b}f16x2 for v2{,b}f16
2023+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2024+
}
2025+
ScalarOps.push_back({ISD::FMINNUM, 2});
2026+
// Definition of minNum in IEEE 754 2008 is non-associative due to handling
2027+
// of sNaN inputs.
2028+
IsAssociative = Flags.hasNoNaNs();
2029+
break;
2030+
case ISD::VECREDUCE_FMAXIMUM:
2031+
if (CanUseMinMax3) {
2032+
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
2033+
// Can't use fmax3 in shuffle reduction
2034+
PrefersShuffle = false;
2035+
} else {
2036+
// Prefer max.{,b}f16x2 for v2{,b}f16
2037+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2038+
}
2039+
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2040+
IsAssociative = true;
2041+
break;
2042+
case ISD::VECREDUCE_FMINIMUM:
2043+
if (CanUseMinMax3) {
2044+
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
2045+
// Can't use fmin3 in shuffle reduction
2046+
PrefersShuffle = false;
2047+
} else {
2048+
// Prefer min.{,b}f16x2 for v2{,b}f16
2049+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2050+
}
2051+
ScalarOps.push_back({ISD::FMINIMUM, 2});
2052+
IsAssociative = true;
2053+
break;
2054+
default:
2055+
llvm_unreachable("unhandled vecreduce operation");
2056+
}
2057+
2058+
// We don't expect an accumulator for reassociative vector reduction ops.
2059+
assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
2060+
2061+
// If shuffle reduction is preferred, leave it to SelectionDAG.
2062+
if (IsAssociative && PrefersShuffle)
2063+
return SDValue();
2064+
2065+
// Otherwise, handle the reduction here.
2066+
SmallVector<SDValue> Elements;
2067+
DAG.ExtractVectorElements(Vector, Elements);
2068+
2069+
// Lower to tree reduction.
2070+
if (IsAssociative)
2071+
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2072+
2073+
// Lower to sequential reduction.
2074+
EVT VectorTy = Vector.getValueType();
2075+
const unsigned NumElts = VectorTy.getVectorNumElements();
2076+
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2077+
// Try to reduce the remaining sequence as much as possible using the
2078+
// current operator.
2079+
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
2080+
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2081+
2082+
if (!Accumulator) {
2083+
// Try to initialize the accumulator using the current operator.
2084+
if (I + DefaultGroupSize <= NumElts) {
2085+
Accumulator = DAG.getNode(
2086+
DefaultScalarOp, DL, EltTy,
2087+
ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
2088+
I += DefaultGroupSize;
2089+
}
2090+
}
2091+
2092+
if (Accumulator) {
2093+
for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
2094+
SmallVector<SDValue> Operands = {Accumulator};
2095+
for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
2096+
Operands.push_back(Elements[I + K]);
2097+
Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
2098+
}
2099+
}
2100+
}
2101+
2102+
return Accumulator;
2103+
}
2104+
19032105
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
19042106
// Handle bitcasting from v2i8 without hitting the default promotion
19052107
// strategy which goes through stack memory.
@@ -2779,6 +2981,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
27792981
return LowerVECTOR_SHUFFLE(Op, DAG);
27802982
case ISD::CONCAT_VECTORS:
27812983
return LowerCONCAT_VECTORS(Op, DAG);
2984+
case ISD::VECREDUCE_FMAX:
2985+
case ISD::VECREDUCE_FMIN:
2986+
case ISD::VECREDUCE_FMAXIMUM:
2987+
case ISD::VECREDUCE_FMINIMUM:
2988+
return LowerVECREDUCE(Op, DAG);
27822989
case ISD::STORE:
27832990
return LowerSTORE(Op, DAG);
27842991
case ISD::LOAD:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ enum NodeType : unsigned {
6464
UNPACK_VECTOR,
6565

6666
FCOPYSIGN,
67+
FMAXNUM3,
68+
FMINNUM3,
69+
FMAXIMUM3,
70+
FMINIMUM3,
71+
6772
DYNAMIC_STACKALLOC,
6873
STACKRESTORE,
6974
STACKSAVE,
@@ -286,6 +291,7 @@ class NVPTXTargetLowering : public TargetLowering {
286291

287292
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
288293
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
294+
SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
289295
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
290296
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
291297
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,36 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
358358
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
359359
}
360360

361+
// Template for the 3-input minimum/maximum instructions, which exist for f32
// only and require sm_100+ with PTX 8.8+.
//
// Every variant takes an FTZ flag operand, so this also covers the ftz (flush
// subnormal inputs and results to sign-preserving zero) forms of the fp32
// instructions. When NaN is set, the ".NaN" modifier is added to the mnemonic.
multiclass FMINIMUMMAXIMUM3<string OpcStr, bit NaN, SDNode OpNode> {
  defvar nan_suffix = !if(NaN, ".NaN", "");
  // Assembly string shared by all operand forms below.
  defvar asm_str = OpcStr # "$ftz" # nan_suffix # ".f32";

  // All three inputs in registers.
  def f32rrr :
    BasicFlagsNVPTXInst<(outs B32:$dst),
                        (ins B32:$a, B32:$b, B32:$c),
                        (ins FTZFlag:$ftz),
                        asm_str,
                        [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
    Requires<[hasPTX<88>, hasSM<100>]>;

  // Last input is an immediate.
  def f32rri :
    BasicFlagsNVPTXInst<(outs B32:$dst),
                        (ins B32:$a, B32:$b, f32imm:$c),
                        (ins FTZFlag:$ftz),
                        asm_str,
                        [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
    Requires<[hasPTX<88>, hasSM<100>]>;

  // Last two inputs are immediates.
  def f32rii :
    BasicFlagsNVPTXInst<(outs B32:$dst),
                        (ins B32:$a, f32imm:$b, f32imm:$c),
                        (ins FTZFlag:$ftz),
                        asm_str,
                        [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
    Requires<[hasPTX<88>, hasSM<100>]>;
}
390+
361391
// Template for instructions which take three FP args. The
362392
// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
363393
//
@@ -1027,6 +1057,20 @@ defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
10271057
defm MIN_NAN : FMINIMUMMAXIMUM<"min", /* NaN */ true, fminimum>;
10281058
defm MAX_NAN : FMINIMUMMAXIMUM<"max", /* NaN */ true, fmaximum>;
10291059

1060+
// Target-specific DAG nodes for the 3-input min/max operations. Each takes
// three floating-point operands (SDTFPTernaryOp) and is marked commutative.

// Counterparts of fminnum/fmaxnum.
def nvptx_fminnum3 :
  SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp, [SDNPCommutative]>;
def nvptx_fmaxnum3 :
  SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp, [SDNPCommutative]>;

// Counterparts of fminimum/fmaximum.
def nvptx_fminimum3 :
  SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp, [SDNPCommutative]>;
def nvptx_fmaximum3 :
  SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp, [SDNPCommutative]>;

// Instruction definitions: the *NAN3 forms carry the ".NaN" modifier.
defm FMIN3    : FMINIMUMMAXIMUM3<"min", /* NaN */ false, nvptx_fminnum3>;
defm FMAX3    : FMINIMUMMAXIMUM3<"max", /* NaN */ false, nvptx_fmaxnum3>;
defm FMINNAN3 : FMINIMUMMAXIMUM3<"min", /* NaN */ true, nvptx_fminimum3>;
defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max", /* NaN */ true, nvptx_fmaximum3>;
1073+
10301074
defm FABS : F2<"abs", fabs>;
10311075
defm FNEG : F2<"neg", fneg>;
10321076
defm FABS_H: F2_Support_Half<"abs", fabs>;

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
8787
}
8888
unsigned getMinVectorRegisterBitWidth() const override { return 32; }
8989

90+
/// Tell the ExpandReductions pass not to expand the reduction intrinsic
/// \p II; the answer is the same for every intrinsic.
bool shouldExpandReduction(const IntrinsicInst *II) const override {
  // NVPTX doesn't have advanced swizzling operations, so the ExpandReductions
  // pass is turned off for it. Our backend/SelectionDAG can expand these
  // reductions with fewer movs.
  constexpr bool ExpandInIR = false;
  return ExpandInIR;
}
96+
9097
// We don't want to prevent inlining because of target-cpu and -features
9198
// attributes that were added to newer versions of LLVM/Clang: There are
9299
// no incompatible functions in PTX, ptxas will throw errors in such cases.

0 commit comments

Comments
 (0)