Skip to content

Commit f07cedb

Browse files
[NVPTX] Reland mad.wide combine under (default off) CLI option (#160214)
Users reported regressions to important matmul kernels as a result of #155024. Although #155024 was a revert, this PR should allow them to recover some of the lost performance.
1 parent 50a7eb6 commit f07cedb

File tree

4 files changed

+369
-215
lines changed

4 files changed

+369
-215
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ static cl::opt<bool>
3838
EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
3939
cl::desc("Enable reciprocal sqrt optimization"));
4040

41+
// FIXME: This is a WAR to recover lost performance from #155024.
42+
// We still need to investigate the regression and find a more permanent
43+
// solution.
44+
static cl::opt<bool> EnableMADWide("nvptx-mad-wide-opt", cl::init(false),
45+
cl::Hidden,
46+
cl::desc("Enable MAD wide optimization"));
47+
4148
/// createNVPTXISelDag - This pass converts a legalized DAG into a
4249
/// NVPTX-specific DAG, ready for instruction scheduling.
4350
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -84,6 +91,8 @@ bool NVPTXDAGToDAGISel::allowFMA() const {
8491

8592
bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }
8693

94+
bool NVPTXDAGToDAGISel::doMADWideOpt() const { return EnableMADWide; }
95+
8796
/// Select - Select instructions not customized! Used for
8897
/// expanded, promoted and normal instructions.
8998
void NVPTXDAGToDAGISel::Select(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
4545
bool useF32FTZ() const;
4646
bool allowFMA() const;
4747
bool doRsqrtOpt() const;
48+
bool doMADWideOpt() const;
4849

4950
NVPTXScopes Scopes{};
5051

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def hasArchAccelFeatures : Predicate<"Subtarget->hasArchAccelFeatures()">;
114114
def doF32FTZ : Predicate<"useF32FTZ()">;
115115
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
116116
def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
117+
def doMADWideOpt : Predicate<"doMADWideOpt()">;
117118

118119
def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
119120
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
@@ -900,8 +901,15 @@ let Predicates = [hasOptEnabled] in {
900901
defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>;
901902
defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>;
902903

903-
// Generating mad.wide causes a regression:
904+
// Generating mad.wide causes a regression in some cases:
904905
// https://github.com/llvm/llvm-project/pull/150477#issuecomment-3191367837
906+
// Only do so when the user requests it.
907+
let Predicates = [doMADWideOpt] in {
908+
defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>;
909+
defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>;
910+
defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>;
911+
defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>;
912+
}
905913
}
906914

907915
//-----------------------------------

0 commit comments

Comments
 (0)