Skip to content

Conversation

@justinfargnoli
Copy link
Contributor

@justinfargnoli justinfargnoli commented Sep 22, 2025

Users reported regressions to important matmul kernels as a result of #155024. Although #155024 was a revert, this PR should allow them to recover some of the lost performance.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces MAD (Multiply-Add) wide optimization support for the NVPTX target in LLVM. The change adds a new command-line option to enable MAD wide operations and implements the corresponding instruction patterns.

  • Adds a new command-line flag nvptx-mad-wide-opt to control MAD wide optimization
  • Implements MAD wide instruction definitions for 16-bit and 32-bit signed/unsigned operations
  • Provides infrastructure to conditionally enable MAD wide optimizations through predicates

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
NVPTXInstrInfo.td Adds predicate and instruction definitions for MAD wide operations
NVPTXISelDAGToDAG.h Declares the doMADWideOpt() method
NVPTXISelDAGToDAG.cpp Implements command-line option and doMADWideOpt() method

@github-actions
Copy link

github-actions bot commented Sep 22, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@justinfargnoli justinfargnoli changed the title Initial commit [NVPTX] Reland mad.wide combine under (default off) CLI option Sep 24, 2025
@justinfargnoli justinfargnoli self-assigned this Sep 24, 2025
@justinfargnoli justinfargnoli marked this pull request as ready for review September 24, 2025 05:57
@llvmbot
Copy link
Member

llvmbot commented Sep 24, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Justin Fargnoli (justinfargnoli)

Changes

Follow-up to #155024 based on reported regressions to important matmul kernels.


Patch is 27.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160214.diff

4 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+6)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+9-1)
  • (modified) llvm/test/CodeGen/NVPTX/combine-wide.ll (+350-214)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index c70f48af33cf2..b7de0a4554cd3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -38,6 +38,10 @@ static cl::opt<bool>
     EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
                    cl::desc("Enable reciprocal sqrt optimization"));
 
+static cl::opt<bool> EnableMADWide("nvptx-mad-wide-opt", cl::init(false),
+                                   cl::Hidden,
+                                   cl::desc("Enable MAD wide optimization"));
+
 /// createNVPTXISelDag - This pass converts a legalized DAG into a
 /// NVPTX-specific DAG, ready for instruction scheduling.
 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -84,6 +88,8 @@ bool NVPTXDAGToDAGISel::allowFMA() const {
 
 bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }
 
+bool NVPTXDAGToDAGISel::doMADWideOpt() const { return EnableMADWide; }
+
 /// Select - Select instructions not customized! Used for
 /// expanded, promoted and normal instructions.
 void NVPTXDAGToDAGISel::Select(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 8dcd5362c4512..c912e709d0aa0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -45,6 +45,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool useF32FTZ() const;
   bool allowFMA() const;
   bool doRsqrtOpt() const;
+  bool doMADWideOpt() const;
 
   NVPTXScopes Scopes{};
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4e38e026e6bda..4e873558b2537 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -114,6 +114,7 @@ def hasArchAccelFeatures : Predicate<"Subtarget->hasArchAccelFeatures()">;
 def doF32FTZ : Predicate<"useF32FTZ()">;
 def doNoF32FTZ : Predicate<"!useF32FTZ()">;
 def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
+def doMADWideOpt : Predicate<"doMADWideOpt()">;
 
 def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
 def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
@@ -899,8 +900,15 @@ let Predicates = [hasOptEnabled] in {
   defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>;
   defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>;
 
-  // Generating mad.wide causes a regression: 
+  // Generating mad.wide causes a regression in some cases: 
   // https://github.com/llvm/llvm-project/pull/150477#issuecomment-3191367837
+  // Only do so when the user requests it.
+  let Predicates = [doMADWideOpt] in {
+    defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>;
+    defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>;
+    defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>;
+    defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>;
+  }
 }
 
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/combine-wide.ll b/llvm/test/CodeGen/NVPTX/combine-wide.ll
index b5948d37c3505..63e0f3789f49f 100644
--- a/llvm/test/CodeGen/NVPTX/combine-wide.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-wide.ll
@@ -1,24 +1,37 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s -O1 | FileCheck %s --check-prefixes=CHECK,O1
+; RUN: llc < %s -O1 | FileCheck %s --check-prefixes=CHECK,O1,O1-NO-MAD
+; RUN: llc < %s -O1 -nvptx-mad-wide-opt | FileCheck %s --check-prefixes=CHECK,O1,O1-MAD
 ; RUN: llc < %s -O0 | FileCheck %s --check-prefixes=CHECK,O0
 
 target triple = "nvptx64-nvidia-cuda"
 
 define i64 @t1(i32 %a, i32 %b, i64 %c) {
-;
-; O1-LABEL: t1(
-; O1:       {
-; O1-NEXT:    .reg .b32 %r<3>;
-; O1-NEXT:    .reg .b64 %rd<4>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b32 %r1, [t1_param_0];
-; O1-NEXT:    ld.param.b32 %r2, [t1_param_1];
-; O1-NEXT:    mul.wide.s32 %rd1, %r1, %r2;
-; O1-NEXT:    ld.param.b64 %rd2, [t1_param_2];
-; O1-NEXT:    add.s64 %rd3, %rd2, %rd1;
-; O1-NEXT:    st.param.b64 [func_retval0], %rd3;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t1(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-NO-MAD-NEXT:    .reg .b64 %rd<4>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b32 %r1, [t1_param_0];
+; O1-NO-MAD-NEXT:    ld.param.b32 %r2, [t1_param_1];
+; O1-NO-MAD-NEXT:    mul.wide.s32 %rd1, %r1, %r2;
+; O1-NO-MAD-NEXT:    ld.param.b64 %rd2, [t1_param_2];
+; O1-NO-MAD-NEXT:    add.s64 %rd3, %rd2, %rd1;
+; O1-NO-MAD-NEXT:    st.param.b64 [func_retval0], %rd3;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t1(
+; O1-MAD:       {
+; O1-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-MAD-NEXT:    .reg .b64 %rd<3>;
+; O1-MAD-EMPTY:
+; O1-MAD-NEXT:  // %bb.0:
+; O1-MAD-NEXT:    ld.param.b32 %r1, [t1_param_0];
+; O1-MAD-NEXT:    ld.param.b32 %r2, [t1_param_1];
+; O1-MAD-NEXT:    ld.param.b64 %rd1, [t1_param_2];
+; O1-MAD-NEXT:    mad.wide.s32 %rd2, %r1, %r2, %rd1;
+; O1-MAD-NEXT:    st.param.b64 [func_retval0], %rd2;
+; O1-MAD-NEXT:    ret;
 ;
 ; O0-LABEL: t1(
 ; O0:       {
@@ -41,20 +54,32 @@ define i64 @t1(i32 %a, i32 %b, i64 %c) {
 }
 
 define i64 @t2(i32 %a, i32 %b, i64 %c) {
-;
-; O1-LABEL: t2(
-; O1:       {
-; O1-NEXT:    .reg .b32 %r<3>;
-; O1-NEXT:    .reg .b64 %rd<4>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b32 %r1, [t2_param_0];
-; O1-NEXT:    ld.param.b32 %r2, [t2_param_1];
-; O1-NEXT:    mul.wide.s32 %rd1, %r1, %r2;
-; O1-NEXT:    ld.param.b64 %rd2, [t2_param_2];
-; O1-NEXT:    add.s64 %rd3, %rd1, %rd2;
-; O1-NEXT:    st.param.b64 [func_retval0], %rd3;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t2(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-NO-MAD-NEXT:    .reg .b64 %rd<4>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b32 %r1, [t2_param_0];
+; O1-NO-MAD-NEXT:    ld.param.b32 %r2, [t2_param_1];
+; O1-NO-MAD-NEXT:    mul.wide.s32 %rd1, %r1, %r2;
+; O1-NO-MAD-NEXT:    ld.param.b64 %rd2, [t2_param_2];
+; O1-NO-MAD-NEXT:    add.s64 %rd3, %rd1, %rd2;
+; O1-NO-MAD-NEXT:    st.param.b64 [func_retval0], %rd3;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t2(
+; O1-MAD:       {
+; O1-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-MAD-NEXT:    .reg .b64 %rd<3>;
+; O1-MAD-EMPTY:
+; O1-MAD-NEXT:  // %bb.0:
+; O1-MAD-NEXT:    ld.param.b32 %r1, [t2_param_0];
+; O1-MAD-NEXT:    ld.param.b32 %r2, [t2_param_1];
+; O1-MAD-NEXT:    ld.param.b64 %rd1, [t2_param_2];
+; O1-MAD-NEXT:    mad.wide.s32 %rd2, %r1, %r2, %rd1;
+; O1-MAD-NEXT:    st.param.b64 [func_retval0], %rd2;
+; O1-MAD-NEXT:    ret;
 ;
 ; O0-LABEL: t2(
 ; O0:       {
@@ -77,19 +102,30 @@ define i64 @t2(i32 %a, i32 %b, i64 %c) {
 }
 
 define i64 @t3(i32 %a, i32 %b) {
-;
-; O1-LABEL: t3(
-; O1:       {
-; O1-NEXT:    .reg .b32 %r<3>;
-; O1-NEXT:    .reg .b64 %rd<3>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b32 %r1, [t3_param_0];
-; O1-NEXT:    ld.param.b32 %r2, [t3_param_1];
-; O1-NEXT:    mul.wide.s32 %rd1, %r1, %r2;
-; O1-NEXT:    add.s64 %rd2, %rd1, 1;
-; O1-NEXT:    st.param.b64 [func_retval0], %rd2;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t3(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-NO-MAD-NEXT:    .reg .b64 %rd<3>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b32 %r1, [t3_param_0];
+; O1-NO-MAD-NEXT:    ld.param.b32 %r2, [t3_param_1];
+; O1-NO-MAD-NEXT:    mul.wide.s32 %rd1, %r1, %r2;
+; O1-NO-MAD-NEXT:    add.s64 %rd2, %rd1, 1;
+; O1-NO-MAD-NEXT:    st.param.b64 [func_retval0], %rd2;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t3(
+; O1-MAD:       {
+; O1-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-MAD-NEXT:    .reg .b64 %rd<2>;
+; O1-MAD-EMPTY:
+; O1-MAD-NEXT:  // %bb.0:
+; O1-MAD-NEXT:    ld.param.b32 %r1, [t3_param_0];
+; O1-MAD-NEXT:    ld.param.b32 %r2, [t3_param_1];
+; O1-MAD-NEXT:    mad.wide.s32 %rd1, %r1, %r2, 1;
+; O1-MAD-NEXT:    st.param.b64 [func_retval0], %rd1;
+; O1-MAD-NEXT:    ret;
 ;
 ; O0-LABEL: t3(
 ; O0:       {
@@ -111,19 +147,30 @@ define i64 @t3(i32 %a, i32 %b) {
 }
 
 define i64 @t4(i32 %a, i64 %c) {
-;
-; O1-LABEL: t4(
-; O1:       {
-; O1-NEXT:    .reg .b32 %r<2>;
-; O1-NEXT:    .reg .b64 %rd<4>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b32 %r1, [t4_param_0];
-; O1-NEXT:    ld.param.b64 %rd1, [t4_param_1];
-; O1-NEXT:    mul.wide.s32 %rd2, %r1, 3;
-; O1-NEXT:    add.s64 %rd3, %rd1, %rd2;
-; O1-NEXT:    st.param.b64 [func_retval0], %rd3;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t4(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b32 %r<2>;
+; O1-NO-MAD-NEXT:    .reg .b64 %rd<4>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b32 %r1, [t4_param_0];
+; O1-NO-MAD-NEXT:    ld.param.b64 %rd1, [t4_param_1];
+; O1-NO-MAD-NEXT:    mul.wide.s32 %rd2, %r1, 3;
+; O1-NO-MAD-NEXT:    add.s64 %rd3, %rd1, %rd2;
+; O1-NO-MAD-NEXT:    st.param.b64 [func_retval0], %rd3;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t4(
+; O1-MAD:       {
+; O1-MAD-NEXT:    .reg .b32 %r<2>;
+; O1-MAD-NEXT:    .reg .b64 %rd<3>;
+; O1-MAD-EMPTY:
+; O1-MAD-NEXT:  // %bb.0:
+; O1-MAD-NEXT:    ld.param.b32 %r1, [t4_param_0];
+; O1-MAD-NEXT:    ld.param.b64 %rd1, [t4_param_1];
+; O1-MAD-NEXT:    mad.wide.s32 %rd2, %r1, 3, %rd1;
+; O1-MAD-NEXT:    st.param.b64 [func_retval0], %rd2;
+; O1-MAD-NEXT:    ret;
 ;
 ; O0-LABEL: t4(
 ; O0:       {
@@ -145,18 +192,28 @@ define i64 @t4(i32 %a, i64 %c) {
 }
 
 define i64 @t4_1(i32 %a, i64 %c) {
-;
-; O1-LABEL: t4_1(
-; O1:       {
-; O1-NEXT:    .reg .b32 %r<2>;
-; O1-NEXT:    .reg .b64 %rd<3>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b32 %r1, [t4_1_param_0];
-; O1-NEXT:    mul.wide.s32 %rd1, %r1, 3;
-; O1-NEXT:    add.s64 %rd2, %rd1, 5;
-; O1-NEXT:    st.param.b64 [func_retval0], %rd2;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t4_1(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b32 %r<2>;
+; O1-NO-MAD-NEXT:    .reg .b64 %rd<3>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b32 %r1, [t4_1_param_0];
+; O1-NO-MAD-NEXT:    mul.wide.s32 %rd1, %r1, 3;
+; O1-NO-MAD-NEXT:    add.s64 %rd2, %rd1, 5;
+; O1-NO-MAD-NEXT:    st.param.b64 [func_retval0], %rd2;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t4_1(
+; O1-MAD:       {
+; O1-MAD-NEXT:    .reg .b32 %r<2>;
+; O1-MAD-NEXT:    .reg .b64 %rd<2>;
+; O1-MAD-EMPTY:
+; O1-MAD-NEXT:  // %bb.0:
+; O1-MAD-NEXT:    ld.param.b32 %r1, [t4_1_param_0];
+; O1-MAD-NEXT:    mad.wide.s32 %rd1, %r1, 3, 5;
+; O1-MAD-NEXT:    st.param.b64 [func_retval0], %rd1;
+; O1-MAD-NEXT:    ret;
 ;
 ; O0-LABEL: t4_1(
 ; O0:       {
@@ -177,20 +234,32 @@ define i64 @t4_1(i32 %a, i64 %c) {
 }
 
 define i64 @t5(i32 %a, i32 %b, i64 %c) {
-;
-; O1-LABEL: t5(
-; O1:       {
-; O1-NEXT:    .reg .b32 %r<3>;
-; O1-NEXT:    .reg .b64 %rd<4>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b32 %r1, [t5_param_0];
-; O1-NEXT:    ld.param.b32 %r2, [t5_param_1];
-; O1-NEXT:    mul.wide.u32 %rd1, %r1, %r2;
-; O1-NEXT:    ld.param.b64 %rd2, [t5_param_2];
-; O1-NEXT:    add.s64 %rd3, %rd2, %rd1;
-; O1-NEXT:    st.param.b64 [func_retval0], %rd3;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t5(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-NO-MAD-NEXT:    .reg .b64 %rd<4>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b32 %r1, [t5_param_0];
+; O1-NO-MAD-NEXT:    ld.param.b32 %r2, [t5_param_1];
+; O1-NO-MAD-NEXT:    mul.wide.u32 %rd1, %r1, %r2;
+; O1-NO-MAD-NEXT:    ld.param.b64 %rd2, [t5_param_2];
+; O1-NO-MAD-NEXT:    add.s64 %rd3, %rd2, %rd1;
+; O1-NO-MAD-NEXT:    st.param.b64 [func_retval0], %rd3;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t5(
+; O1-MAD:       {
+; O1-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-MAD-NEXT:    .reg .b64 %rd<3>;
+; O1-MAD-EMPTY:
+; O1-MAD-NEXT:  // %bb.0:
+; O1-MAD-NEXT:    ld.param.b32 %r1, [t5_param_0];
+; O1-MAD-NEXT:    ld.param.b32 %r2, [t5_param_1];
+; O1-MAD-NEXT:    ld.param.b64 %rd1, [t5_param_2];
+; O1-MAD-NEXT:    mad.wide.u32 %rd2, %r1, %r2, %rd1;
+; O1-MAD-NEXT:    st.param.b64 [func_retval0], %rd2;
+; O1-MAD-NEXT:    ret;
 ;
 ; O0-LABEL: t5(
 ; O0:       {
@@ -213,20 +282,32 @@ define i64 @t5(i32 %a, i32 %b, i64 %c) {
 }
 
 define i64 @t6(i32 %a, i32 %b, i64 %c) {
-;
-; O1-LABEL: t6(
-; O1:       {
-; O1-NEXT:    .reg .b32 %r<3>;
-; O1-NEXT:    .reg .b64 %rd<4>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b32 %r1, [t6_param_0];
-; O1-NEXT:    ld.param.b32 %r2, [t6_param_1];
-; O1-NEXT:    mul.wide.u32 %rd1, %r1, %r2;
-; O1-NEXT:    ld.param.b64 %rd2, [t6_param_2];
-; O1-NEXT:    add.s64 %rd3, %rd1, %rd2;
-; O1-NEXT:    st.param.b64 [func_retval0], %rd3;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t6(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-NO-MAD-NEXT:    .reg .b64 %rd<4>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b32 %r1, [t6_param_0];
+; O1-NO-MAD-NEXT:    ld.param.b32 %r2, [t6_param_1];
+; O1-NO-MAD-NEXT:    mul.wide.u32 %rd1, %r1, %r2;
+; O1-NO-MAD-NEXT:    ld.param.b64 %rd2, [t6_param_2];
+; O1-NO-MAD-NEXT:    add.s64 %rd3, %rd1, %rd2;
+; O1-NO-MAD-NEXT:    st.param.b64 [func_retval0], %rd3;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t6(
+; O1-MAD:       {
+; O1-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-MAD-NEXT:    .reg .b64 %rd<3>;
+; O1-MAD-EMPTY:
+; O1-MAD-NEXT:  // %bb.0:
+; O1-MAD-NEXT:    ld.param.b32 %r1, [t6_param_0];
+; O1-MAD-NEXT:    ld.param.b32 %r2, [t6_param_1];
+; O1-MAD-NEXT:    ld.param.b64 %rd1, [t6_param_2];
+; O1-MAD-NEXT:    mad.wide.u32 %rd2, %r1, %r2, %rd1;
+; O1-MAD-NEXT:    st.param.b64 [func_retval0], %rd2;
+; O1-MAD-NEXT:    ret;
 ;
 ; O0-LABEL: t6(
 ; O0:       {
@@ -249,7 +330,6 @@ define i64 @t6(i32 %a, i32 %b, i64 %c) {
 }
 
 define i32 @t7(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t7(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<4>;
@@ -281,7 +361,6 @@ define i32 @t7(i16 %a, i16 %b) {
 }
 
 define i32 @t8(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t8(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<4>;
@@ -313,7 +392,6 @@ define i32 @t8(i16 %a, i16 %b) {
 }
 
 define i64 @t9(i32 %a, i32 %b) {
-;
 ; O1-LABEL: t9(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<4>;
@@ -345,7 +423,6 @@ define i64 @t9(i32 %a, i32 %b) {
 }
 
 define i64 @t10(i32 %a, i32 %b) {
-;
 ; O1-LABEL: t10(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<4>;
@@ -377,7 +454,6 @@ define i64 @t10(i32 %a, i32 %b) {
 }
 
 define i32 @t11(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t11(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<4>;
@@ -409,7 +485,6 @@ define i32 @t11(i16 %a, i16 %b) {
 }
 
 define i32 @t12(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t12(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<3>;
@@ -440,7 +515,6 @@ define i32 @t12(i16 %a, i16 %b) {
 }
 
 define i64 @t13(i32 %a, i32 %b) {
-;
 ; O1-LABEL: t13(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<4>;
@@ -472,7 +546,6 @@ define i64 @t13(i32 %a, i32 %b) {
 }
 
 define i64 @t14(i32 %a, i32 %b) {
-;
 ; O1-LABEL: t14(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<3>;
@@ -503,7 +576,6 @@ define i64 @t14(i32 %a, i32 %b) {
 }
 
 define i32 @t15(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t15(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<3>;
@@ -534,7 +606,6 @@ define i32 @t15(i16 %a, i16 %b) {
 }
 
 define i32 @t16(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t16(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<4>;
@@ -566,7 +637,6 @@ define i32 @t16(i16 %a, i16 %b) {
 }
 
 define i64 @t17(i32 %a, i32 %b) {
-;
 ; O1-LABEL: t17(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<3>;
@@ -597,7 +667,6 @@ define i64 @t17(i32 %a, i32 %b) {
 }
 
 define i64 @t18(i32 %a, i32 %b) {
-;
 ; O1-LABEL: t18(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<4>;
@@ -629,7 +698,6 @@ define i64 @t18(i32 %a, i32 %b) {
 }
 
 define i32 @t19(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t19(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<4>;
@@ -661,7 +729,6 @@ define i32 @t19(i16 %a, i16 %b) {
 }
 
 define i32 @t20(i16 %a) {
-;
 ; CHECK-LABEL: t20(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<3>;
@@ -679,7 +746,6 @@ define i32 @t20(i16 %a) {
 }
 
 define i64 @t21(i32 %a) {
-;
 ; CHECK-LABEL: t21(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<3>;
@@ -697,7 +763,6 @@ define i64 @t21(i32 %a) {
 }
 
 define i64 @t22(i32 %a) {
-;
 ; CHECK-LABEL: t22(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<3>;
@@ -715,7 +780,6 @@ define i64 @t22(i32 %a) {
 }
 
 define i32 @t23(i16 %a, i16 %b) {
-;
 ; CHECK-LABEL: t23(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<3>;
@@ -733,7 +797,6 @@ define i32 @t23(i16 %a, i16 %b) {
 }
 
 define i32 @t24(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t24(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<2>;
@@ -762,7 +825,6 @@ define i32 @t24(i16 %a, i16 %b) {
 }
 
 define i64 @t25(i32 %a) {
-;
 ; CHECK-LABEL: t25(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<3>;
@@ -780,7 +842,6 @@ define i64 @t25(i32 %a) {
 }
 
 define i64 @t26(i32 %a) {
-;
 ; O1-LABEL: t26(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<2>;
@@ -809,7 +870,6 @@ define i64 @t26(i32 %a) {
 }
 
 define i32 @t27(i16 %a, i16 %b) {
-;
 ; O1-LABEL: t27(
 ; O1:       {
 ; O1-NEXT:    .reg .b16 %rs<2>;
@@ -838,7 +898,6 @@ define i32 @t27(i16 %a, i16 %b) {
 }
 
 define i32 @t28(i16 %a, i16 %b) {
-;
 ; CHECK-LABEL: t28(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<3>;
@@ -856,7 +915,6 @@ define i32 @t28(i16 %a, i16 %b) {
 }
 
 define i64 @t29(i32 %a) {
-;
 ; O1-LABEL: t29(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<2>;
@@ -885,7 +943,6 @@ define i64 @t29(i32 %a) {
 }
 
 define i64 @t30(i32 %a) {
-;
 ; CHECK-LABEL: t30(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<3>;
@@ -903,7 +960,6 @@ define i64 @t30(i32 %a) {
 }
 
 define i64 @t31(i32 %a, i32 %b) {
-;
 ; O1-LABEL: t31(
 ; O1:       {
 ; O1-NEXT:    .reg .b32 %r<4>;
@@ -935,20 +991,32 @@ define i64 @t31(i32 %a, i32 %b) {
 }
 
 define i32 @t32(i16 %a, i16 %b, i32 %c) {
-;
-; O1-LABEL: t32(
-; O1:       {
-; O1-NEXT:    .reg .b16 %rs<3>;
-; O1-NEXT:    .reg .b32 %r<4>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b16 %rs1, [t32_param_0];
-; O1-NEXT:    ld.param.b16 %rs2, [t32_param_1];
-; O1-NEXT:    mul.wide.s16 %r1, %rs1, %rs2;
-; O1-NEXT:    ld.param.b32 %r2, [t32_param_2];
-; O1-NEXT:    add.s32 %r3, %r2, %r1;
-; O1-NEXT:    st.param.b32 [func_retval0], %r3;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t32(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b16 %rs<3>;
+; O1-NO-MAD-NEXT:    .reg .b32 %r<4>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b16 %rs1, [t32_param_0];
+; O1-NO-MAD-NEXT:    ld.param.b16 %rs2, [t32_param_1];
+; O1-NO-MAD-NEXT:    mul.wide.s16 %r1, %rs1, %rs2;
+; O1-NO-MAD-NEXT:    ld.param.b32 %r2, [t32_param_2];
+; O1-NO-MAD-NEXT:    add.s32 %r3, %r2, %r1;
+; O1-NO-MAD-NEXT:    st.param.b32 [func_retval0], %r3;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t32(
+; O1-MAD:       {
+; O1-MAD-NEXT:    .reg .b16 %rs<3>;
+; O1-MAD-NEXT:    .reg .b32 %r<3>;
+; O1-MAD-EMPTY:
+; O1-MAD-NEXT:  // %bb.0:
+; O1-MAD-NEXT:    ld.param.b16 %rs1, [t32_param_0];
+; O1-MAD-NEXT:    ld.param.b16 %rs2, [t32_param_1];
+; O1-MAD-NEXT:    ld.param.b32 %r1, [t32_param_2];
+; O1-MAD-NEXT:    mad.wide.s16 %r2, %rs1, %rs2, %r1;
+; O1-MAD-NEXT:    st.param.b32 [func_retval0], %r2;
+; O1-MAD-NEXT:    ret;
 ;
 ; O0-LABEL: t32(
 ; O0:       {
@@ -971,20 +1039,32 @@ define i32 @t32(i16 %a, i16 %b, i32 %c) {
 }
 
 define i32 @t33(i16 %a, i16 %b, i32 %c) {
-;
-; O1-LABEL: t33(
-; O1:       {
-; O1-NEXT:    .reg .b16 %rs<3>;
-; O1-NEXT:    .reg .b32 %r<4>;
-; O1-EMPTY:
-; O1-NEXT:  // %bb.0:
-; O1-NEXT:    ld.param.b16 %rs1, [t33_param_0];
-; O1-NEXT:    ld.param.b16 %rs2, [t33_param_1];
-; O1-NEXT:    mul.wide.s16 %r1, %rs1, %rs2;
-; O1-NEXT:    ld.param.b32 %r2, [t33_param_2];
-; O1-NEXT:    add.s32 %r3, %r2, %r1;
-; O1-NEXT:    st.param.b32 [func_retval0], %r3;
-; O1-NEXT:    ret;
+; O1-NO-MAD-LABEL: t33(
+; O1-NO-MAD:       {
+; O1-NO-MAD-NEXT:    .reg .b16 %rs<3>;
+; O1-NO-MAD-NEXT:    .reg .b32 %r<4>;
+; O1-NO-MAD-EMPTY:
+; O1-NO-MAD-NEXT:  // %bb.0:
+; O1-NO-MAD-NEXT:    ld.param.b16 %rs1, [t33_param_0];
+; O1-NO-MAD-NEXT:    ld.param.b16 %rs2, [t33_param_1];
+; O1-NO-MAD-NEXT:    mul.wide.s16 %r1, %rs1, %rs2;
+; O1-NO-MAD-NEXT:    ld.param.b32 %r2, [t33_param_2];
+; O1-NO-MAD-NEXT:    add.s32 %r3, %r2, %r1;
+; O1-NO-MAD-NEXT:    st.param.b32 [func_retval0], %r3;
+; O1-NO-MAD-NEXT:    ret;
+;
+; O1-MAD-LABEL: t33(
+; O1-MAD:     ...
[truncated]

Copy link
Contributor

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this

@justinfargnoli
Copy link
Contributor Author

justinfargnoli commented Sep 24, 2025

I plan on merging tomorrow morning PST to give time for @Artem-B or @AlexMaclean to chime in.

Copy link
Member

Artem-B commented Sep 24, 2025

Do we have a good idea what exactly triggered the performance regression?
Based on the #155024 (comment)
It appears to be "ptxas can't optimize loads that use mad.wide result as the source address".

If that's the main reason we can't enable IMAD unconditionally, we should make it a prominent FIXME/TODO next to the knob controlling it now.

@justinfargnoli
Copy link
Contributor Author

Do we have a good idea what exactly triggered the performance regression?

No, I haven't had the bandwidth to diagnose the issue properly.


It appears to be "ptxas can't optimize loads that use mad.wide result as the source address".

Apologies if I'm being too pedantic, but I'd guess that ptxas can do this, but for whatever reason, it doesn't do it right now.


we should make it a prominent FIXME/TODO next to the knob controlling it now.

Added a FIXME with 204513f.

Copy link
Member

Artem-B commented Sep 24, 2025

I've poked at the SASS a bit and indeed, ptxas does something rather strange: https://godbolt.org/z/bvb47sWYc

In both cases the base address addition appears to be "baked" into the load instruction itself as UR4 in LDG.E.128 R4, desc[UR4][R4.64]. The offset is specified separately.

I guess ptxas trips on attempting to extract the base address back from the mad.wide, while it can handle it in mul/add case.

Copy link
Member

Artem-B commented Sep 25, 2025

Too bad ptx does not allow specifying non-constant source offset. That would obviate the need for doing multiply/add, only to have ptxas do extra work to extract the base address back.
Meanwhile, we could try limit fusing mul/add into mad if the result is used as the source in the loads/stores on recent GPU variants. I think that's as helpful as we can be in order to generate ptxas-friendly PTX.

@justinfargnoli
Copy link
Contributor Author

Meanwhile, we could try limit fusing mul/add into mad if the result is used as the source in the loads/stores on recent GPU variants.

I'm currently working on this, but it will likely be a couple of weeks before I have the time to finish it up.

@justinfargnoli justinfargnoli merged commit f07cedb into llvm:main Sep 25, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…m#160214)

Users reported regressions to important matmul kernels as a result of
llvm#155024. Although llvm#155024 was a revert, this PR should allow them to
recover some of the lost performance.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants