Skip to content

Commit 430e546

Browse files
committed
[NVPTX] lower VECREDUCE max/min to 3-input on sm_100+
Add support for 3-input fmaxnum/fminnum/fmaximum/fminimum introduced in PTX 8.8 for sm_100+: - Use a tree reduction when 3-input operations are supported and the reduction has the `reassoc` flag. - If not on sm_100+/PTX 8.8, fall back to 2-input operations and use the default shuffle reduction.
1 parent d4e8619 commit 430e546

File tree

5 files changed

+1098
-599
lines changed

5 files changed

+1098
-599
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
850850
if (STI.allowFP16Math() || STI.hasBF16Math())
851851
setTargetDAGCombine(ISD::SETCC);
852852

853+
// Vector reduction operations. These may be turned into shuffle or tree
854+
// reductions depending on what instructions are available for each type.
855+
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
856+
MVT EltVT = VT.getVectorElementType();
857+
if (EltVT == MVT::f32 || EltVT == MVT::f64) {
858+
setOperationAction({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
859+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
860+
VT, Custom);
861+
}
862+
}
863+
853864
// Promote fp16 arithmetic if fp16 hardware isn't available or the
854865
// user passed --nvptx-no-fp16-math. The flag is useful because,
855866
// although sm_53+ GPUs have some sort of FP16 support in
@@ -1093,6 +1104,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10931104
MAKE_CASE(NVPTXISD::BFI)
10941105
MAKE_CASE(NVPTXISD::PRMT)
10951106
MAKE_CASE(NVPTXISD::FCOPYSIGN)
1107+
MAKE_CASE(NVPTXISD::FMAXNUM3)
1108+
MAKE_CASE(NVPTXISD::FMINNUM3)
1109+
MAKE_CASE(NVPTXISD::FMAXIMUM3)
1110+
MAKE_CASE(NVPTXISD::FMINIMUM3)
10961111
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
10971112
MAKE_CASE(NVPTXISD::STACKRESTORE)
10981113
MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -1900,6 +1915,125 @@ static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
19001915
return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
19011916
}
19021917

1918+
/// Reduces the elements using the scalar operations provided. The operations
1919+
/// are sorted descending in number of inputs they take. The flags on the
1920+
/// original reduction operation will be propagated to each scalar operation.
1921+
/// Nearby elements are grouped in tree reduction, unlike the shuffle reduction
1922+
/// used in ExpandReductions and SelectionDAG.
1923+
static SDValue buildTreeReduction(
1924+
const SmallVector<SDValue> &Elements, EVT EltTy,
1925+
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
1926+
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
1927+
// Build the reduction tree at each level, starting with all the elements.
1928+
SmallVector<SDValue> Level = Elements;
1929+
1930+
unsigned OpIdx = 0;
1931+
while (Level.size() > 1) {
1932+
// Try to reduce this level using the current operator.
1933+
const auto [Op, NumInputs] = Ops[OpIdx];
1934+
1935+
// Build the next level by partially reducing all elements.
1936+
SmallVector<SDValue> ReducedLevel;
1937+
unsigned I = 0, E = Level.size();
1938+
for (; I + NumInputs <= E; I += NumInputs) {
1939+
// Reduce elements in groups of [NumInputs], as much as possible.
1940+
ReducedLevel.push_back(DAG.getNode(
1941+
Op, DL, EltTy, ArrayRef<SDValue>(Level).slice(I, NumInputs), Flags));
1942+
}
1943+
1944+
if (I < E) {
1945+
// Handle leftover elements.
1946+
1947+
if (ReducedLevel.empty()) {
1948+
// We didn't reduce anything at this level. We need to pick a smaller
1949+
// operator.
1950+
++OpIdx;
1951+
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
1952+
continue;
1953+
}
1954+
1955+
// We reduced some things but there's still more left, meaning the
1956+
// operator's number of inputs doesn't evenly divide this level size. Move
1957+
// these elements to the next level.
1958+
for (; I < E; ++I)
1959+
ReducedLevel.push_back(Level[I]);
1960+
}
1961+
1962+
// Process the next level.
1963+
Level = ReducedLevel;
1964+
}
1965+
1966+
return *Level.begin();
1967+
}
1968+
1969+
// Get scalar reduction opcode
1970+
static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
1971+
switch (ReductionOpcode) {
1972+
case ISD::VECREDUCE_FMAX:
1973+
return ISD::FMAXNUM;
1974+
case ISD::VECREDUCE_FMIN:
1975+
return ISD::FMINNUM;
1976+
case ISD::VECREDUCE_FMAXIMUM:
1977+
return ISD::FMAXIMUM;
1978+
case ISD::VECREDUCE_FMINIMUM:
1979+
return ISD::FMINIMUM;
1980+
default:
1981+
llvm_unreachable("unhandled reduction opcode");
1982+
}
1983+
}
1984+
1985+
/// Get 3-input scalar reduction opcode
1986+
static std::optional<NVPTXISD::NodeType>
1987+
getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
1988+
switch (ReductionOpcode) {
1989+
case ISD::VECREDUCE_FMAX:
1990+
return NVPTXISD::FMAXNUM3;
1991+
case ISD::VECREDUCE_FMIN:
1992+
return NVPTXISD::FMINNUM3;
1993+
case ISD::VECREDUCE_FMAXIMUM:
1994+
return NVPTXISD::FMAXIMUM3;
1995+
case ISD::VECREDUCE_FMINIMUM:
1996+
return NVPTXISD::FMINIMUM3;
1997+
default:
1998+
return std::nullopt;
1999+
}
2000+
}
2001+
2002+
/// Lower reductions to either a sequence of operations or a tree if
2003+
/// reassociations are allowed. This method will use larger operations like
2004+
/// max3/min3 when the target supports them.
2005+
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2006+
SelectionDAG &DAG) const {
2007+
SDLoc DL(Op);
2008+
const SDNodeFlags Flags = Op->getFlags();
2009+
SDValue Vector = Op.getOperand(0);
2010+
2011+
const unsigned Opcode = Op->getOpcode();
2012+
const EVT EltTy = Vector.getValueType().getVectorElementType();
2013+
2014+
// Whether we can use 3-input min/max when expanding the reduction.
2015+
const bool CanUseMinMax3 =
2016+
EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
2017+
STI.getPTXVersion() >= 88 &&
2018+
(Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
2019+
Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);
2020+
2021+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2022+
// number of inputs they take.
2023+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2024+
2025+
if (auto Opcode3Elem = getScalar3OpcodeForReduction(Opcode);
2026+
CanUseMinMax3 && Opcode3Elem)
2027+
ScalarOps.push_back({*Opcode3Elem, 3});
2028+
ScalarOps.push_back({getScalarOpcodeForReduction(Opcode), 2});
2029+
2030+
// Otherwise, handle the reduction here.
2031+
SmallVector<SDValue> Elements;
2032+
DAG.ExtractVectorElements(Vector, Elements);
2033+
2034+
return buildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2035+
}
2036+
19032037
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
19042038
// Handle bitcasting from v2i8 without hitting the default promotion
19052039
// strategy which goes through stack memory.
@@ -2779,6 +2913,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
27792913
return LowerVECTOR_SHUFFLE(Op, DAG);
27802914
case ISD::CONCAT_VECTORS:
27812915
return LowerCONCAT_VECTORS(Op, DAG);
2916+
case ISD::VECREDUCE_FMAX:
2917+
case ISD::VECREDUCE_FMIN:
2918+
case ISD::VECREDUCE_FMAXIMUM:
2919+
case ISD::VECREDUCE_FMINIMUM:
2920+
return LowerVECREDUCE(Op, DAG);
27822921
case ISD::STORE:
27832922
return LowerSTORE(Op, DAG);
27842923
case ISD::LOAD:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ enum NodeType : unsigned {
6464
UNPACK_VECTOR,
6565

6666
FCOPYSIGN,
67+
FMAXNUM3,
68+
FMINNUM3,
69+
FMAXIMUM3,
70+
FMINIMUM3,
71+
6772
DYNAMIC_STACKALLOC,
6873
STACKRESTORE,
6974
STACKSAVE,
@@ -286,6 +291,7 @@ class NVPTXTargetLowering : public TargetLowering {
286291

287292
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
288293
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
294+
SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
289295
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
290296
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
291297
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,36 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
356356
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
357357
}
358358

359+
// Template for 3-input minimum/maximum instructions
// (sm_100+/PTX 8.8 and f32 only).
//
// OpcStr is the instruction mnemonic, NaN selects whether ".NaN" is appended
// in the assembly string, and OpNode is the SelectionDAG node being matched.
// Three operand forms are defined: all registers (rrr), one trailing
// immediate (rri) and two trailing immediates (rii).
//
// Every form also takes an FTZFlag operand, printed directly after the
// mnemonic, which provides the ftz (flush subnormal inputs and results to
// sign-preserving zero) variants for fp32 functions.
multiclass FMINIMUMMAXIMUM3<string OpcStr, bit NaN, SDNode OpNode> {
  defvar nan_suffix = !if(NaN, ".NaN", "");
  // The assembly string is shared by all three operand forms.
  defvar asm_str = OpcStr # "$ftz" # nan_suffix # ".f32";

  // reg, reg, reg
  def f32rrr :
    BasicFlagsNVPTXInst<(outs B32:$dst),
                        (ins B32:$a, B32:$b, B32:$c),
                        (ins FTZFlag:$ftz),
                        asm_str,
                        [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
    Requires<[hasPTX<88>, hasSM<100>]>;

  // reg, reg, imm
  def f32rri :
    BasicFlagsNVPTXInst<(outs B32:$dst),
                        (ins B32:$a, B32:$b, f32imm:$c),
                        (ins FTZFlag:$ftz),
                        asm_str,
                        [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
    Requires<[hasPTX<88>, hasSM<100>]>;

  // reg, imm, imm
  def f32rii :
    BasicFlagsNVPTXInst<(outs B32:$dst),
                        (ins B32:$a, f32imm:$b, f32imm:$c),
                        (ins FTZFlag:$ftz),
                        asm_str,
                        [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
    Requires<[hasPTX<88>, hasSM<100>]>;
}
388+
359389
// Template for instructions which take three FP args. The
360390
// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
361391
//
@@ -971,6 +1001,20 @@ defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
9711001
defm MIN_NAN : FMINIMUMMAXIMUM<"min", /* NaN */ true, fminimum>;
9721002
defm MAX_NAN : FMINIMUMMAXIMUM<"max", /* NaN */ true, fmaximum>;
9731003

1004+
// SelectionDAG nodes for the NVPTX-specific 3-input floating-point min/max
// operations. They all share the same shape: a ternary FP type profile and
// the commutative property.
class NVPTXFMinMax3Node<string Name>
    : SDNode<"NVPTXISD::" # Name, SDTFPTernaryOp, [SDNPCommutative]>;

def nvptx_fminnum3  : NVPTXFMinMax3Node<"FMINNUM3">;
def nvptx_fmaxnum3  : NVPTXFMinMax3Node<"FMAXNUM3">;
def nvptx_fminimum3 : NVPTXFMinMax3Node<"FMINIMUM3">;
def nvptx_fmaximum3 : NVPTXFMinMax3Node<"FMAXIMUM3">;

// 3-input min/max instructions. The *NAN3 definitions use the ".NaN" form of
// the assembly string and match the fminimum/fmaximum flavoured nodes.
defm FMIN3    : FMINIMUMMAXIMUM3<"min", /* NaN */ false, nvptx_fminnum3>;
defm FMAX3    : FMINIMUMMAXIMUM3<"max", /* NaN */ false, nvptx_fmaxnum3>;
defm FMINNAN3 : FMINIMUMMAXIMUM3<"min", /* NaN */ true, nvptx_fminimum3>;
defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max", /* NaN */ true, nvptx_fmaximum3>;
1017+
9741018
defm FABS : F2<"abs", fabs>;
9751019
defm FNEG : F2<"neg", fneg>;
9761020
defm FABS_H: F2_Support_Half<"abs", fabs>;

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
8787
}
8888
unsigned getMinVectorRegisterBitWidth() const override { return 32; }
8989

90+
/// Returns false for every reduction intrinsic, which turns off the
/// ExpandReductions pass for NVPTX. The target doesn't have advanced
/// swizzling operations, and our backend/SelectionDAG can expand these
/// reductions with fewer movs.
bool shouldExpandReduction(const IntrinsicInst *II) const override {
  return false;
}
96+
9097
// We don't want to prevent inlining because of target-cpu and -features
9198
// attributes that were added to newer versions of LLVM/Clang: There are
9299
// no incompatible functions in PTX, ptxas will throw errors in such cases.

0 commit comments

Comments
 (0)