Skip to content

Commit 9a592d9

Browse files
authored
[NVPTX] lower VECREDUCE min/max to 3-input on sm_100+ (#136253)
Add support for 3-input fmaxnum/fminnum/fmaximum/fminimum introduced in PTX 8.8 for sm_100+: - Use a tree reduction when 3-input operations are supported and the reduction has the `reassoc` flag. - If not on sm_100+/PTX 8.8, fallback to 2-input operations and use the default shuffle reduction.
1 parent a04142f commit 9a592d9

File tree

5 files changed

+1097
-599
lines changed

5 files changed

+1097
-599
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
900900
if (STI.allowFP16Math() || STI.hasBF16Math())
901901
setTargetDAGCombine(ISD::SETCC);
902902

903+
// Vector reduction operations. These may be turned into shuffle or tree
904+
// reductions depending on what instructions are available for each type.
905+
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
906+
MVT EltVT = VT.getVectorElementType();
907+
if (EltVT == MVT::f32 || EltVT == MVT::f64) {
908+
setOperationAction({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
909+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
910+
VT, Custom);
911+
}
912+
}
913+
903914
// Promote fp16 arithmetic if fp16 hardware isn't available or the
904915
// user passed --nvptx-no-fp16-math. The flag is useful because,
905916
// although sm_53+ GPUs have some sort of FP16 support in
@@ -1143,6 +1154,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
11431154
MAKE_CASE(NVPTXISD::BFI)
11441155
MAKE_CASE(NVPTXISD::PRMT)
11451156
MAKE_CASE(NVPTXISD::FCOPYSIGN)
1157+
MAKE_CASE(NVPTXISD::FMAXNUM3)
1158+
MAKE_CASE(NVPTXISD::FMINNUM3)
1159+
MAKE_CASE(NVPTXISD::FMAXIMUM3)
1160+
MAKE_CASE(NVPTXISD::FMINIMUM3)
11461161
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
11471162
MAKE_CASE(NVPTXISD::STACKRESTORE)
11481163
MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -1929,6 +1944,124 @@ static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
19291944
return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
19301945
}
19311946

1947+
/// Reduces the elements using the scalar operations provided. The operations
1948+
/// are sorted descending in number of inputs they take. The flags on the
1949+
/// original reduction operation will be propagated to each scalar operation.
1950+
/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
1951+
/// used in ExpandReductions and SelectionDAG.
1952+
static SDValue buildTreeReduction(
1953+
const SmallVector<SDValue> &Elements, EVT EltTy,
1954+
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
1955+
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1956+
// Build the reduction tree at each level, starting with all the elements.
1957+
SmallVector<SDValue> Level = Elements;
1958+
1959+
unsigned OpIdx = 0;
1960+
while (Level.size() > 1) {
1961+
// Try to reduce this level using the current operator.
1962+
const auto [Op, NumInputs] = Ops[OpIdx];
1963+
1964+
// Build the next level by partially reducing all elements.
1965+
SmallVector<SDValue> ReducedLevel;
1966+
unsigned I = 0, E = Level.size();
1967+
for (; I + NumInputs <= E; I += NumInputs) {
1968+
// Reduce elements in groups of [NumInputs], as much as possible.
1969+
ReducedLevel.push_back(DAG.getNode(
1970+
Op, DL, EltTy, ArrayRef<SDValue>(Level).slice(I, NumInputs), Flags));
1971+
}
1972+
1973+
if (I < E) {
1974+
// Handle leftover elements.
1975+
1976+
if (ReducedLevel.empty()) {
1977+
// We didn't reduce anything at this level. We need to pick a smaller
1978+
// operator.
1979+
++OpIdx;
1980+
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
1981+
continue;
1982+
}
1983+
1984+
// We reduced some things but there's still more left, meaning the
1985+
// operator's number of inputs doesn't evenly divide this level size. Move
1986+
// these elements to the next level.
1987+
for (; I < E; ++I)
1988+
ReducedLevel.push_back(Level[I]);
1989+
}
1990+
1991+
// Process the next level.
1992+
Level = ReducedLevel;
1993+
}
1994+
1995+
return *Level.begin();
1996+
}
1997+
1998+
// Get scalar reduction opcode
1999+
static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
2000+
switch (ReductionOpcode) {
2001+
case ISD::VECREDUCE_FMAX:
2002+
return ISD::FMAXNUM;
2003+
case ISD::VECREDUCE_FMIN:
2004+
return ISD::FMINNUM;
2005+
case ISD::VECREDUCE_FMAXIMUM:
2006+
return ISD::FMAXIMUM;
2007+
case ISD::VECREDUCE_FMINIMUM:
2008+
return ISD::FMINIMUM;
2009+
default:
2010+
llvm_unreachable("unhandled reduction opcode");
2011+
}
2012+
}
2013+
2014+
/// Get 3-input scalar reduction opcode
2015+
static std::optional<NVPTXISD::NodeType>
2016+
getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
2017+
switch (ReductionOpcode) {
2018+
case ISD::VECREDUCE_FMAX:
2019+
return NVPTXISD::FMAXNUM3;
2020+
case ISD::VECREDUCE_FMIN:
2021+
return NVPTXISD::FMINNUM3;
2022+
case ISD::VECREDUCE_FMAXIMUM:
2023+
return NVPTXISD::FMAXIMUM3;
2024+
case ISD::VECREDUCE_FMINIMUM:
2025+
return NVPTXISD::FMINIMUM3;
2026+
default:
2027+
return std::nullopt;
2028+
}
2029+
}
2030+
2031+
/// Lower reductions to either a sequence of operations or a tree if
2032+
/// reassociations are allowed. This method will use larger operations like
2033+
/// max3/min3 when the target supports them.
2034+
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2035+
SelectionDAG &DAG) const {
2036+
SDLoc DL(Op);
2037+
const SDNodeFlags Flags = Op->getFlags();
2038+
SDValue Vector = Op.getOperand(0);
2039+
2040+
const unsigned Opcode = Op->getOpcode();
2041+
const EVT EltTy = Vector.getValueType().getVectorElementType();
2042+
2043+
// Whether we can use 3-input min/max when expanding the reduction.
2044+
const bool CanUseMinMax3 =
2045+
EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
2046+
STI.getPTXVersion() >= 88 &&
2047+
(Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
2048+
Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
2049+
2050+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2051+
// number of inputs they take.
2052+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2053+
2054+
if (auto Opcode3Elem = getScalar3OpcodeForReduction(Opcode);
2055+
CanUseMinMax3 && Opcode3Elem)
2056+
ScalarOps.push_back({*Opcode3Elem, 3});
2057+
ScalarOps.push_back({getScalarOpcodeForReduction(Opcode), 2});
2058+
2059+
SmallVector<SDValue> Elements;
2060+
DAG.ExtractVectorElements(Vector, Elements);
2061+
2062+
return buildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2063+
}
2064+
19322065
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
19332066
// Handle bitcasting from v2i8 without hitting the default promotion
19342067
// strategy which goes through stack memory.
@@ -2808,6 +2941,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28082941
return LowerVECTOR_SHUFFLE(Op, DAG);
28092942
case ISD::CONCAT_VECTORS:
28102943
return LowerCONCAT_VECTORS(Op, DAG);
2944+
case ISD::VECREDUCE_FMAX:
2945+
case ISD::VECREDUCE_FMIN:
2946+
case ISD::VECREDUCE_FMAXIMUM:
2947+
case ISD::VECREDUCE_FMINIMUM:
2948+
return LowerVECREDUCE(Op, DAG);
28112949
case ISD::STORE:
28122950
return LowerSTORE(Op, DAG);
28132951
case ISD::LOAD:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ enum NodeType : unsigned {
6464
UNPACK_VECTOR,
6565

6666
FCOPYSIGN,
67+
FMAXNUM3,
68+
FMINNUM3,
69+
FMAXIMUM3,
70+
FMINIMUM3,
71+
6772
DYNAMIC_STACKALLOC,
6873
STACKRESTORE,
6974
STACKSAVE,
@@ -286,6 +291,7 @@ class NVPTXTargetLowering : public TargetLowering {
286291

287292
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
288293
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
294+
SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
289295
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
290296
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
291297
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,36 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
347347
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
348348
}
349349

350+
// Template for 3-input minimum/maximum instructions
351+
// (sm_100+/PTX 8.8 and f32 only)
352+
//
353+
// Also defines ftz (flush subnormal inputs and results to sign-preserving
354+
// zero) variants for fp32 functions.
355+
multiclass FMINIMUMMAXIMUM3<string OpcStr, bit NaN, SDNode OpNode> {
356+
defvar nan_str = !if(NaN, ".NaN", "");
357+
def f32rrr :
358+
BasicFlagsNVPTXInst<(outs B32:$dst),
359+
(ins B32:$a, B32:$b, B32:$c),
360+
(ins FTZFlag:$ftz),
361+
OpcStr # "$ftz" # nan_str # ".f32",
362+
[(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
363+
Requires<[hasPTX<88>, hasSM<100>]>;
364+
def f32rri :
365+
BasicFlagsNVPTXInst<(outs B32:$dst),
366+
(ins B32:$a, B32:$b, f32imm:$c),
367+
(ins FTZFlag:$ftz),
368+
OpcStr # "$ftz" # nan_str # ".f32",
369+
[(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
370+
Requires<[hasPTX<88>, hasSM<100>]>;
371+
def f32rii :
372+
BasicFlagsNVPTXInst<(outs B32:$dst),
373+
(ins B32:$a, f32imm:$b, f32imm:$c),
374+
(ins FTZFlag:$ftz),
375+
OpcStr # "$ftz" # nan_str # ".f32",
376+
[(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
377+
Requires<[hasPTX<88>, hasSM<100>]>;
378+
}
379+
350380
// Template for instructions which take three FP args. The
351381
// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
352382
//
@@ -900,6 +930,20 @@ defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
900930
defm MIN_NAN : FMINIMUMMAXIMUM<"min", /* NaN */ true, fminimum>;
901931
defm MAX_NAN : FMINIMUMMAXIMUM<"max", /* NaN */ true, fmaximum>;
902932

933+
def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
934+
[SDNPCommutative]>;
935+
def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
936+
[SDNPCommutative]>;
937+
def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
938+
[SDNPCommutative]>;
939+
def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
940+
[SDNPCommutative]>;
941+
942+
defm FMIN3 : FMINIMUMMAXIMUM3<"min", /* NaN */ false, nvptx_fminnum3>;
943+
defm FMAX3 : FMINIMUMMAXIMUM3<"max", /* NaN */ false, nvptx_fmaxnum3>;
944+
defm FMINNAN3 : FMINIMUMMAXIMUM3<"min", /* NaN */ true, nvptx_fminimum3>;
945+
defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max", /* NaN */ true, nvptx_fmaximum3>;
946+
903947
defm FABS : F2<"abs", fabs>;
904948
defm FNEG : F2<"neg", fneg>;
905949
defm FABS_H: F2_Support_Half<"abs", fabs>;

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
8787
}
8888
unsigned getMinVectorRegisterBitWidth() const override { return 32; }
8989

90+
bool shouldExpandReduction(const IntrinsicInst *II) const override {
91+
// Turn off ExpandReductions pass for NVPTX, which doesn't have advanced
92+
// swizzling operations. Our backend/Selection DAG can expand these
93+
// reductions with less movs.
94+
return false;
95+
}
96+
9097
// We don't want to prevent inlining because of target-cpu and -features
9198
// attributes that were added to newer versions of LLVM/Clang: There are
9299
// no incompatible functions in PTX, ptxas will throw errors in such cases.

0 commit comments

Comments
 (0)