Conversation

@AlexMaclean
Member

These classes are redundant, as the untyped "Int" classes can be used for all float operations. This change is intended to be as minimal as possible and leaves the many potential simplifications and refactors this exposes as future work.
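
For a concrete picture of the effect, here is a hand-written sketch (not taken from the test diffs below, with hypothetical `fadd` parameter names) of a trivial f32 add: the instructions themselves are unchanged, only the register class moves from the dedicated `%f` registers to the untyped `.b32` `%r` registers.

```
// Before this change (Float32Regs):
//   .reg .b32 %f<4>;
//   ld.param.b32 %f1, [fadd_param_0];
//   ld.param.b32 %f2, [fadd_param_1];
//   add.rn.f32   %f3, %f1, %f2;
//   st.param.b32 [func_retval0], %f3;

// After this change (f32 lives in Int32Regs, f64 in Int64Regs):
.reg .b32 %r<4>;
ld.param.b32 %r1, [fadd_param_0];
ld.param.b32 %r2, [fadd_param_1];
add.rn.f32   %r3, %r1, %r2;
st.param.b32 [func_retval0], %r3;
```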

@llvmbot
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

These classes are redundant, as the untyped "Int" classes can be used for all float operations. This change is intended to be as minimal as possible and leaves the many potential simplifications and refactors this exposes as future work.


Patch is 1.24 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140487.diff

89 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (-4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+4-7)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp (+2-10)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp (-8)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+6-6)
  • (modified) llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll (+53-53)
  • (modified) llvm/test/CodeGen/NVPTX/access-non-generic.ll (+10-10)
  • (modified) llvm/test/CodeGen/NVPTX/aggregate-return.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/and-or-setcc.ll (+12-14)
  • (modified) llvm/test/CodeGen/NVPTX/arithmetic-fp-sm20.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/atomics-with-scope.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/atomics.ll (+40-41)
  • (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+486-596)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll (+14-16)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+118-135)
  • (modified) llvm/test/CodeGen/NVPTX/bug22322.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/convert-fp-i8.ll (+29-31)
  • (modified) llvm/test/CodeGen/NVPTX/convert-fp.ll (+26-26)
  • (modified) llvm/test/CodeGen/NVPTX/convert-sm100.ll (+16-20)
  • (modified) llvm/test/CodeGen/NVPTX/convert-sm100a.ll (+60-70)
  • (modified) llvm/test/CodeGen/NVPTX/convert-sm80.ll (+66-77)
  • (modified) llvm/test/CodeGen/NVPTX/convert-sm90.ll (+16-20)
  • (modified) llvm/test/CodeGen/NVPTX/copysign.ll (+34-34)
  • (modified) llvm/test/CodeGen/NVPTX/distributed-shared-cluster.ll (+4-5)
  • (modified) llvm/test/CodeGen/NVPTX/div.ll (+10-10)
  • (modified) llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll (+6-7)
  • (modified) llvm/test/CodeGen/NVPTX/f16-abs.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/f16-instructions.ll (+128-128)
  • (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+435-481)
  • (modified) llvm/test/CodeGen/NVPTX/f32-ex2.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/f32-lg2.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/fabs-intrinsics.ll (+12-12)
  • (modified) llvm/test/CodeGen/NVPTX/fexp2.ll (+116-133)
  • (modified) llvm/test/CodeGen/NVPTX/flog2.ll (+70-82)
  • (modified) llvm/test/CodeGen/NVPTX/fma-assoc.ll (+39-9)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-contract.ll (+311-406)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll (+243-315)
  • (modified) llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll (+504-654)
  • (modified) llvm/test/CodeGen/NVPTX/fma.ll (+94-14)
  • (modified) llvm/test/CodeGen/NVPTX/fp-contract.ll (+40-40)
  • (modified) llvm/test/CodeGen/NVPTX/fp-literals.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/frem.ll (+132-132)
  • (modified) llvm/test/CodeGen/NVPTX/i1-int-to-fp.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+7-9)
  • (modified) llvm/test/CodeGen/NVPTX/inline-asm.ll (+29-2)
  • (modified) llvm/test/CodeGen/NVPTX/intrinsics.ll (+16-16)
  • (modified) llvm/test/CodeGen/NVPTX/ld-generic.ll (+121-24)
  • (modified) llvm/test/CodeGen/NVPTX/ld-st-addrrspace.py (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll (+34-36)
  • (modified) llvm/test/CodeGen/NVPTX/ldg-invariant.ll (+31-33)
  • (modified) llvm/test/CodeGen/NVPTX/ldparam-v4.ll (+18-4)
  • (modified) llvm/test/CodeGen/NVPTX/ldu-ldg.ll (+12-14)
  • (modified) llvm/test/CodeGen/NVPTX/load-store-256-addressing-invariant.ll (+25-27)
  • (modified) llvm/test/CodeGen/NVPTX/load-store-256-addressing.ll (+25-27)
  • (modified) llvm/test/CodeGen/NVPTX/load-store-scalars.ll (+256-288)
  • (modified) llvm/test/CodeGen/NVPTX/load-store-sm-70.ll (+2697-1176)
  • (modified) llvm/test/CodeGen/NVPTX/load-store-sm-90.ll (+1065-456)
  • (modified) llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll (+102-112)
  • (modified) llvm/test/CodeGen/NVPTX/load-store-vectors.ll (+136-144)
  • (modified) llvm/test/CodeGen/NVPTX/math-intrins.ll (+523-559)
  • (modified) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+7-8)
  • (modified) llvm/test/CodeGen/NVPTX/param-add.ll (+11-13)
  • (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/param-overalign.ll (+16-16)
  • (modified) llvm/test/CodeGen/NVPTX/proxy-reg-erasure-ptx.ll (+7-7)
  • (modified) llvm/test/CodeGen/NVPTX/rcp-opt.ll (+15-15)
  • (modified) llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll (+201-201)
  • (modified) llvm/test/CodeGen/NVPTX/redux-sync-f32.ll (+40-48)
  • (modified) llvm/test/CodeGen/NVPTX/reg-types.ll (+2-4)
  • (modified) llvm/test/CodeGen/NVPTX/shfl-p.ll (+225-80)
  • (modified) llvm/test/CodeGen/NVPTX/shfl-sync-p.ll (+16-16)
  • (modified) llvm/test/CodeGen/NVPTX/shfl.ll (+5-5)
  • (modified) llvm/test/CodeGen/NVPTX/st-addrspace.ll (+12-12)
  • (modified) llvm/test/CodeGen/NVPTX/st-generic.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+83-83)
  • (modified) llvm/test/CodeGen/NVPTX/surf-read-cuda.ll (+6-8)
  • (modified) llvm/test/CodeGen/NVPTX/surf-read.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/surf-tex.py (+5-5)
  • (modified) llvm/test/CodeGen/NVPTX/tag-invariant-loads.ll (+6-8)
  • (modified) llvm/test/CodeGen/NVPTX/tex-read-cuda.ll (+11-14)
  • (modified) llvm/test/CodeGen/NVPTX/tex-read.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/vaargs.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+11-12)
  • (modified) llvm/test/CodeGen/NVPTX/vec-param-load.ll (+10-10)
  • (modified) llvm/test/CodeGen/NVPTX/vector-args.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/vector-loads.ll (+37-37)
  • (modified) llvm/test/CodeGen/NVPTX/wmma.py (+2-2)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 0e5207cf9b04c..e2e42ff771336 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -223,10 +223,6 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
       Ret = (3 << 28);
     } else if (RC == &NVPTX::Int64RegsRegClass) {
       Ret = (4 << 28);
-    } else if (RC == &NVPTX::Float32RegsRegClass) {
-      Ret = (5 << 28);
-    } else if (RC == &NVPTX::Float64RegsRegClass) {
-      Ret = (6 << 28);
     } else if (RC == &NVPTX::Int128RegsRegClass) {
       Ret = (7 << 28);
     } else {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 82d00ef8eccb9..9a82db31e43a0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -596,8 +596,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
-  addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
-  addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
+  addRegisterClass(MVT::f32, &NVPTX::Int32RegsRegClass);
+  addRegisterClass(MVT::f64, &NVPTX::Int64RegsRegClass);
   addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
   addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
   addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
@@ -4931,13 +4931,14 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
     case 'b':
       return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
     case 'c':
-      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
     case 'h':
       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
     case 'r':
+    case 'f':
       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
     case 'l':
     case 'N':
+    case 'd':
       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
     case 'q': {
       if (STI.getSmVersion() < 70)
@@ -4945,10 +4946,6 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                            "supported for sm_70 and higher!");
       return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
     }
-    case 'f':
-      return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
-    case 'd':
-      return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
     }
   }
   return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index 67dc7904a91ae..f262a0fb66c25 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -44,19 +44,11 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   } else if (DestRC == &NVPTX::Int16RegsRegClass) {
     Op = NVPTX::MOV16r;
   } else if (DestRC == &NVPTX::Int32RegsRegClass) {
-    Op = (SrcRC == &NVPTX::Int32RegsRegClass ? NVPTX::IMOV32r
-                                             : NVPTX::BITCONVERT_32_F2I);
+    Op = NVPTX::IMOV32r;
   } else if (DestRC == &NVPTX::Int64RegsRegClass) {
-    Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64r
-                                             : NVPTX::BITCONVERT_64_F2I);
+    Op = NVPTX::IMOV64r;
   } else if (DestRC == &NVPTX::Int128RegsRegClass) {
     Op = NVPTX::IMOV128r;
-  } else if (DestRC == &NVPTX::Float32RegsRegClass) {
-    Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32r
-                                               : NVPTX::BITCONVERT_32_I2F);
-  } else if (DestRC == &NVPTX::Float64RegsRegClass) {
-    Op = (SrcRC == &NVPTX::Float64RegsRegClass ? NVPTX::FMOV64r
-                                               : NVPTX::BITCONVERT_64_I2F);
   } else {
     llvm_unreachable("Bad register copy");
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
index 6b9797c3e6aae..eb60e1502cf90 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -25,10 +25,6 @@ using namespace llvm;
 
 namespace llvm {
 StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
-  if (RC == &NVPTX::Float32RegsRegClass)
-    return ".b32";
-  if (RC == &NVPTX::Float64RegsRegClass)
-    return ".b64";
   if (RC == &NVPTX::Int128RegsRegClass)
     return ".b128";
   if (RC == &NVPTX::Int64RegsRegClass)
@@ -63,10 +59,6 @@ StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
 }
 
 StringRef getNVPTXRegClassStr(TargetRegisterClass const *RC) {
-  if (RC == &NVPTX::Float32RegsRegClass)
-    return "%f";
-  if (RC == &NVPTX::Float64RegsRegClass)
-    return "%fd";
   if (RC == &NVPTX::Int128RegsRegClass)
     return "%rq";
   if (RC == &NVPTX::Int64RegsRegClass)
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
index 2011f0f7e328f..2eea9e9721cdf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -40,8 +40,6 @@ foreach i = 0...4 in {
   def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit
   def H#i  : NVPTXReg<"%h"#i>;  // 16-bit float
   def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float
-  def F#i  : NVPTXReg<"%f"#i>;  // 32-bit float
-  def FL#i : NVPTXReg<"%fd"#i>; // 64-bit float
 
   // Arguments
   def ia#i : NVPTXReg<"%ia"#i>;
@@ -59,14 +57,13 @@ foreach i = 0...31 in {
 //===----------------------------------------------------------------------===//
 def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
 def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
-def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
+def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
                               (add (sequence "R%u", 0, 4),
                               VRFrame32, VRFrameLocal32)>;
-def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
+def Int64Regs : NVPTXRegClass<[i64, f64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
 // 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
 def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
-def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
-def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
+
 def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
 def Int64ArgRegs : NVPTXRegClass<[i64], 64, (add (sequence "la%u", 0, 4))>;
 def Float32ArgRegs : NVPTXRegClass<[f32], 32, (add (sequence "fa%u", 0, 4))>;
@@ -75,3 +72,6 @@ def Float64ArgRegs : NVPTXRegClass<[f64], 64, (add (sequence "da%u", 0, 4))>;
 // Read NVPTXRegisterInfo.cpp to see how VRFrame and VRDepot are used.
 def SpecialRegs : NVPTXRegClass<[i32], 32, (add VRFrame32, VRFrameLocal32, VRDepot,
                                             (sequence "ENVREG%u", 0, 31))>;
+
+defvar Float32Regs = Int32Regs;
+defvar Float64Regs = Int64Regs;
diff --git a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
index 78b57badc06e8..1207c429524ca 100644
--- a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
+++ b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
@@ -45,36 +45,36 @@ define half @fh(ptr %p) {
 ; ENABLED-LABEL: fh(
 ; ENABLED:       {
 ; ENABLED-NEXT:    .reg .b16 %rs<10>;
-; ENABLED-NEXT:    .reg .b32 %f<13>;
+; ENABLED-NEXT:    .reg .b32 %r<13>;
 ; ENABLED-NEXT:    .reg .b64 %rd<2>;
 ; ENABLED-EMPTY:
 ; ENABLED-NEXT:  // %bb.0:
 ; ENABLED-NEXT:    ld.param.b64 %rd1, [fh_param_0];
 ; ENABLED-NEXT:    ld.v4.b16 {%rs1, %rs2, %rs3, %rs4}, [%rd1];
 ; ENABLED-NEXT:    ld.b16 %rs5, [%rd1+8];
-; ENABLED-NEXT:    cvt.f32.f16 %f1, %rs2;
-; ENABLED-NEXT:    cvt.f32.f16 %f2, %rs1;
-; ENABLED-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; ENABLED-NEXT:    cvt.rn.f16.f32 %rs6, %f3;
-; ENABLED-NEXT:    cvt.f32.f16 %f4, %rs4;
-; ENABLED-NEXT:    cvt.f32.f16 %f5, %rs3;
-; ENABLED-NEXT:    add.rn.f32 %f6, %f5, %f4;
-; ENABLED-NEXT:    cvt.rn.f16.f32 %rs7, %f6;
-; ENABLED-NEXT:    cvt.f32.f16 %f7, %rs7;
-; ENABLED-NEXT:    cvt.f32.f16 %f8, %rs6;
-; ENABLED-NEXT:    add.rn.f32 %f9, %f8, %f7;
-; ENABLED-NEXT:    cvt.rn.f16.f32 %rs8, %f9;
-; ENABLED-NEXT:    cvt.f32.f16 %f10, %rs8;
-; ENABLED-NEXT:    cvt.f32.f16 %f11, %rs5;
-; ENABLED-NEXT:    add.rn.f32 %f12, %f10, %f11;
-; ENABLED-NEXT:    cvt.rn.f16.f32 %rs9, %f12;
+; ENABLED-NEXT:    cvt.f32.f16 %r1, %rs2;
+; ENABLED-NEXT:    cvt.f32.f16 %r2, %rs1;
+; ENABLED-NEXT:    add.rn.f32 %r3, %r2, %r1;
+; ENABLED-NEXT:    cvt.rn.f16.f32 %rs6, %r3;
+; ENABLED-NEXT:    cvt.f32.f16 %r4, %rs4;
+; ENABLED-NEXT:    cvt.f32.f16 %r5, %rs3;
+; ENABLED-NEXT:    add.rn.f32 %r6, %r5, %r4;
+; ENABLED-NEXT:    cvt.rn.f16.f32 %rs7, %r6;
+; ENABLED-NEXT:    cvt.f32.f16 %r7, %rs7;
+; ENABLED-NEXT:    cvt.f32.f16 %r8, %rs6;
+; ENABLED-NEXT:    add.rn.f32 %r9, %r8, %r7;
+; ENABLED-NEXT:    cvt.rn.f16.f32 %rs8, %r9;
+; ENABLED-NEXT:    cvt.f32.f16 %r10, %rs8;
+; ENABLED-NEXT:    cvt.f32.f16 %r11, %rs5;
+; ENABLED-NEXT:    add.rn.f32 %r12, %r10, %r11;
+; ENABLED-NEXT:    cvt.rn.f16.f32 %rs9, %r12;
 ; ENABLED-NEXT:    st.param.b16 [func_retval0], %rs9;
 ; ENABLED-NEXT:    ret;
 ;
 ; DISABLED-LABEL: fh(
 ; DISABLED:       {
 ; DISABLED-NEXT:    .reg .b16 %rs<10>;
-; DISABLED-NEXT:    .reg .b32 %f<13>;
+; DISABLED-NEXT:    .reg .b32 %r<13>;
 ; DISABLED-NEXT:    .reg .b64 %rd<2>;
 ; DISABLED-EMPTY:
 ; DISABLED-NEXT:  // %bb.0:
@@ -84,22 +84,22 @@ define half @fh(ptr %p) {
 ; DISABLED-NEXT:    ld.b16 %rs3, [%rd1+4];
 ; DISABLED-NEXT:    ld.b16 %rs4, [%rd1+6];
 ; DISABLED-NEXT:    ld.b16 %rs5, [%rd1+8];
-; DISABLED-NEXT:    cvt.f32.f16 %f1, %rs2;
-; DISABLED-NEXT:    cvt.f32.f16 %f2, %rs1;
-; DISABLED-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; DISABLED-NEXT:    cvt.rn.f16.f32 %rs6, %f3;
-; DISABLED-NEXT:    cvt.f32.f16 %f4, %rs4;
-; DISABLED-NEXT:    cvt.f32.f16 %f5, %rs3;
-; DISABLED-NEXT:    add.rn.f32 %f6, %f5, %f4;
-; DISABLED-NEXT:    cvt.rn.f16.f32 %rs7, %f6;
-; DISABLED-NEXT:    cvt.f32.f16 %f7, %rs7;
-; DISABLED-NEXT:    cvt.f32.f16 %f8, %rs6;
-; DISABLED-NEXT:    add.rn.f32 %f9, %f8, %f7;
-; DISABLED-NEXT:    cvt.rn.f16.f32 %rs8, %f9;
-; DISABLED-NEXT:    cvt.f32.f16 %f10, %rs8;
-; DISABLED-NEXT:    cvt.f32.f16 %f11, %rs5;
-; DISABLED-NEXT:    add.rn.f32 %f12, %f10, %f11;
-; DISABLED-NEXT:    cvt.rn.f16.f32 %rs9, %f12;
+; DISABLED-NEXT:    cvt.f32.f16 %r1, %rs2;
+; DISABLED-NEXT:    cvt.f32.f16 %r2, %rs1;
+; DISABLED-NEXT:    add.rn.f32 %r3, %r2, %r1;
+; DISABLED-NEXT:    cvt.rn.f16.f32 %rs6, %r3;
+; DISABLED-NEXT:    cvt.f32.f16 %r4, %rs4;
+; DISABLED-NEXT:    cvt.f32.f16 %r5, %rs3;
+; DISABLED-NEXT:    add.rn.f32 %r6, %r5, %r4;
+; DISABLED-NEXT:    cvt.rn.f16.f32 %rs7, %r6;
+; DISABLED-NEXT:    cvt.f32.f16 %r7, %rs7;
+; DISABLED-NEXT:    cvt.f32.f16 %r8, %rs6;
+; DISABLED-NEXT:    add.rn.f32 %r9, %r8, %r7;
+; DISABLED-NEXT:    cvt.rn.f16.f32 %rs8, %r9;
+; DISABLED-NEXT:    cvt.f32.f16 %r10, %rs8;
+; DISABLED-NEXT:    cvt.f32.f16 %r11, %rs5;
+; DISABLED-NEXT:    add.rn.f32 %r12, %r10, %r11;
+; DISABLED-NEXT:    cvt.rn.f16.f32 %rs9, %r12;
 ; DISABLED-NEXT:    st.param.b16 [func_retval0], %rs9;
 ; DISABLED-NEXT:    ret;
   %p.1 = getelementptr half, ptr %p, i32 1
@@ -121,37 +121,37 @@ define half @fh(ptr %p) {
 define float @ff(ptr %p) {
 ; ENABLED-LABEL: ff(
 ; ENABLED:       {
-; ENABLED-NEXT:    .reg .b32 %f<10>;
+; ENABLED-NEXT:    .reg .b32 %r<10>;
 ; ENABLED-NEXT:    .reg .b64 %rd<2>;
 ; ENABLED-EMPTY:
 ; ENABLED-NEXT:  // %bb.0:
 ; ENABLED-NEXT:    ld.param.b64 %rd1, [ff_param_0];
-; ENABLED-NEXT:    ld.v4.b32 {%f1, %f2, %f3, %f4}, [%rd1];
-; ENABLED-NEXT:    ld.b32 %f5, [%rd1+16];
-; ENABLED-NEXT:    add.rn.f32 %f6, %f1, %f2;
-; ENABLED-NEXT:    add.rn.f32 %f7, %f3, %f4;
-; ENABLED-NEXT:    add.rn.f32 %f8, %f6, %f7;
-; ENABLED-NEXT:    add.rn.f32 %f9, %f8, %f5;
-; ENABLED-NEXT:    st.param.b32 [func_retval0], %f9;
+; ENABLED-NEXT:    ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; ENABLED-NEXT:    ld.b32 %r5, [%rd1+16];
+; ENABLED-NEXT:    add.rn.f32 %r6, %r1, %r2;
+; ENABLED-NEXT:    add.rn.f32 %r7, %r3, %r4;
+; ENABLED-NEXT:    add.rn.f32 %r8, %r6, %r7;
+; ENABLED-NEXT:    add.rn.f32 %r9, %r8, %r5;
+; ENABLED-NEXT:    st.param.b32 [func_retval0], %r9;
 ; ENABLED-NEXT:    ret;
 ;
 ; DISABLED-LABEL: ff(
 ; DISABLED:       {
-; DISABLED-NEXT:    .reg .b32 %f<10>;
+; DISABLED-NEXT:    .reg .b32 %r<10>;
 ; DISABLED-NEXT:    .reg .b64 %rd<2>;
 ; DISABLED-EMPTY:
 ; DISABLED-NEXT:  // %bb.0:
 ; DISABLED-NEXT:    ld.param.b64 %rd1, [ff_param_0];
-; DISABLED-NEXT:    ld.b32 %f1, [%rd1];
-; DISABLED-NEXT:    ld.b32 %f2, [%rd1+4];
-; DISABLED-NEXT:    ld.b32 %f3, [%rd1+8];
-; DISABLED-NEXT:    ld.b32 %f4, [%rd1+12];
-; DISABLED-NEXT:    ld.b32 %f5, [%rd1+16];
-; DISABLED-NEXT:    add.rn.f32 %f6, %f1, %f2;
-; DISABLED-NEXT:    add.rn.f32 %f7, %f3, %f4;
-; DISABLED-NEXT:    add.rn.f32 %f8, %f6, %f7;
-; DISABLED-NEXT:    add.rn.f32 %f9, %f8, %f5;
-; DISABLED-NEXT:    st.param.b32 [func_retval0], %f9;
+; DISABLED-NEXT:    ld.b32 %r1, [%rd1];
+; DISABLED-NEXT:    ld.b32 %r2, [%rd1+4];
+; DISABLED-NEXT:    ld.b32 %r3, [%rd1+8];
+; DISABLED-NEXT:    ld.b32 %r4, [%rd1+12];
+; DISABLED-NEXT:    ld.b32 %r5, [%rd1+16];
+; DISABLED-NEXT:    add.rn.f32 %r6, %r1, %r2;
+; DISABLED-NEXT:    add.rn.f32 %r7, %r3, %r4;
+; DISABLED-NEXT:    add.rn.f32 %r8, %r6, %r7;
+; DISABLED-NEXT:    add.rn.f32 %r9, %r8, %r5;
+; DISABLED-NEXT:    st.param.b32 [func_retval0], %r9;
 ; DISABLED-NEXT:    ret;
   %p.1 = getelementptr float, ptr %p, i32 1
   %p.2 = getelementptr float, ptr %p, i32 2
diff --git a/llvm/test/CodeGen/NVPTX/access-non-generic.ll b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
index a816f2e84b064..9edd4de017ee2 100644
--- a/llvm/test/CodeGen/NVPTX/access-non-generic.ll
+++ b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
@@ -23,10 +23,10 @@ define void @ld_st_shared_f32(i32 %i, float %v) {
   ; load cast
   %1 = load float, ptr addrspacecast (ptr addrspace(3) @scalar to ptr), align 4
   call void @use(float %1)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [scalar];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [scalar];
   ; store cast
   store float %v, ptr addrspacecast (ptr addrspace(3) @scalar to ptr), align 4
-; PTX: st.shared.b32 [scalar], %f{{[0-9]+}};
+; PTX: st.shared.b32 [scalar], %r{{[0-9]+}};
   ; use syncthreads to disable optimizations across components
   call void @llvm.nvvm.barrier0()
 ; PTX: bar.sync 0;
@@ -35,20 +35,20 @@ define void @ld_st_shared_f32(i32 %i, float %v) {
   %2 = addrspacecast ptr addrspace(3) @scalar to ptr
   %3 = load float, ptr %2, align 4
   call void @use(float %3)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [scalar];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [scalar];
   ; cast; store
   store float %v, ptr %2, align 4
-; PTX: st.shared.b32 [scalar], %f{{[0-9]+}};
+; PTX: st.shared.b32 [scalar], %r{{[0-9]+}};
   call void @llvm.nvvm.barrier0()
 ; PTX: bar.sync 0;
 
   ; load gep cast
   %4 = load float, ptr getelementptr inbounds ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i32 0, i32 5), align 4
   call void @use(float %4)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [array+20];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [array+20];
   ; store gep cast
   store float %v, ptr getelementptr inbounds ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i32 0, i32 5), align 4
-; PTX: st.shared.b32 [array+20], %f{{[0-9]+}};
+; PTX: st.shared.b32 [array+20], %r{{[0-9]+}};
   call void @llvm.nvvm.barrier0()
 ; PTX: bar.sync 0;
 
@@ -56,10 +56,10 @@ define void @ld_st_shared_f32(i32 %i, float %v) {
   %5 = getelementptr inbounds [10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i32 0, i32 5
   %6 = load float, ptr %5, align 4
   call void @use(float %6)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [array+20];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [array+20];
   ; gep cast; store
   store float %v, ptr %5, align 4
-; PTX: st.shared.b32 [array+20], %f{{[0-9]+}};
+; PTX: st.shared.b32 [array+20], %r{{[0-9]+}};
   call void @llvm.nvvm.barrier0()
 ; PTX: bar.sync 0;
 
@@ -68,10 +68,10 @@ define void @ld_st_shared_f32(i32 %i, float %v) {
   %8 = getelementptr inbounds [10 x float], ptr %7, i32 0, i32 %i
   %9 = load float, ptr %8, align 4
   call void @use(float %9)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [%{{(r|rl|rd)[0-9]+}}];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [%{{(r|rl|rd)[0-9]+}}];
   ; cast; gep; store
   store float %v, ptr %8, align 4
-; PTX: st.shared.b32 [%{{(r|rl|rd)[0-9]+}}], %f{{[0-9]+}};
+; PTX: st.shared.b32 [%{{(r|rl|rd)[0-9]+}}], %r{{[0-9]+}};
   call void @llvm.nvvm.barrier0()
 ; PTX: bar.sync 0;
 
diff --git a/llvm/test/CodeGen/NVPTX/aggregate-return.ll b/llvm/test/CodeGen/NVPTX/aggregate-return.ll
index 72c302433f081..1c8f019922e37 100644
--- a/llvm/test/CodeGen/NVPTX/aggregate-return.ll
+++ b/llvm/test/CodeGen/NVPTX/aggregate-return.ll
@@ -10,7 +10,7 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
 ; CHECK-LABEL: @test_v2f32
   %call = tail call <2 x float> @barv(<2 x float> %input)
 ; CHECK: .param .align 8 .b8 retval0[8];
-; CHECK: ld.param.v2.b32 {[[E0:%f[0-9]+]], [[E1:%f[0-9]+]]}, [retval0];
+; CHECK: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
   store <2 x float> %call, ptr %output, align 8
 ; CHECK: st.v2.b32 [{{%rd[0-9]+}}], {[[E0]], [[E1]]}
   ret void
@@ -21,10 +21,10 @@ define void @test_v3f32(<3 x float> %input, ptr %output) {
 ;
   %call = tail call <3 x float> @barv3(<3 x float> %input)
 ; CHECK: .param .align 16 .b8 retval0[16];
-; CHECK-DAG: ld.param.v2.b32 {[[E0:%f[0-9]+]], [[E1:%f[0-9]+]]}, [retval0];
-; CHECK-DAG: ld.param.b32 [[E2:%f[0-9]+]], [retval0+8];
+; CHECK-DAG: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
+; CHECK-DAG: ld.param.b32 [[E2:%r[0-9]+]], [retval0+8];
 ; Make sure we don't load more values than than we need to.
-; CHECK-NOT: ld.param.b32 [[E3:%f[0-9]+]], [retval0+12];
+; CHECK-NOT: ld.param.b32 [[E3:%r[0-9]+]], [retval0+12];
   store <3 x float> %call, ptr %output, align 8
 ; CHECK-DAG: st.b32 [{{%rd[0-9]}}+8],
 ; -- This is suboptimal. We should do st.v2.f32 instead
@@ -38,8 +38,8 @@ define void @test_a2f32([2 x float] %input, ptr %output) {
 ; CHECK-LABEL: @test_a2f32
   %call = tail call [2 x float] @bara([2 x float] %input)
 ; CHECK: .param .align 4 .b8 retval0[8];
-; CHECK-DAG: ld.param.b32 [[ELEMA1:%f[0-9]+]], [retval0];
-; CHECK-DAG: ld.param.b32 [[ELEMA2:%f[0-9]+]], [retval0+4];
+; CHECK-DAG: ld.param.b32 [[ELEMA1:%r[0-9]+]], [retval0];
+; CHECK-DAG: ld.param.b32 [[ELEMA2:%r[0-9]+]], [retval0+4];
   store [2 x float] %call, ptr %output, align 4
 ; CHECK: }
 ; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMA1]]
@@ -52,8 +52,8 @@ define void @test_s2f32({float, float} %input, ptr %output) {
 ; CHECK-LABEL: @test_s2f32
   %call = tail call {float, float} @bars({float, float} %input)
 ; CHECK: .param .align 4 .b8 retval0[8];
-; CHECK-DAG: ld.param.b32 [[ELEMS1:%f[0-9]+]], [retval0];
-; CHECK-DAG: ld.param.b32 [[ELEMS2:%f[0-9]+]], [retval0+4];
+; CHECK-DAG: ld.param.b32 [[ELEMS1:%r[0-9]+]], [retval0];
+; CHECK-DAG: ld.param.b32 [[ELEMS2:%r[0-9]+]], [retval0+4];
   store {float, float} %call, ptr %output, align 4
 ; CHECK: }
 ; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMS1]]
diff --git a/llvm/test/CodeGen/NVPTX/and-or-setcc.ll b/llvm/test/CodeGen/NVPTX/and-or-setcc.ll
index 53c741bd6cb2c..b7e6e8b85298a 100644
--- a/llvm/test/CodeGen/NVPTX/and-or-setcc.ll
+++ b/llvm/test/CodeGen/NVPTX/and-or-setcc.ll
@@ -8,15 +8,14 @@ define i1 @and_ord(float %a, float %b) {
 ; CHECK-LABEL: and_ord(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .pred %p<2>;
-; CHECK-NEXT:    .reg .b32 %r<2>;
-; CHECK-NEXT:    .reg .b32 %f<3>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b32 %f1, [and_ord_param_0];
-; CHECK-NEXT:    ld.param.b32 %f2, [and_ord_param_1];
-; CHECK-NEXT:    setp.num.f32 %p1, %f1, %f2;
-; CHECK-NEXT:    selp.b32 %r1, 1, 0, %p1;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ld.param.b32 %r1, [and_ord_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [and_ord_param_1];
+; CHECK-NEXT:    setp.num.f32 %p1, %r1, %r2;
+; CHECK-NEXT:    selp.b32 %r3, 1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retv...
[truncated]

Member

@Artem-B left a comment


LGTM.

I've checked how far back ptxas supports FP operations on .b32/.b64 registers, and it appears to work in CUDA versions as old as 9.1: https://godbolt.org/z/nbvPe57dc
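
For reference, a minimal self-contained module in the same spirit as that check (the kernel name is hypothetical and the `.version`/`.target` pair is just one plausible choice); it performs f32 and f64 arithmetic directly on `.b32`/`.b64` registers, which is exactly what ptxas has to accept for this patch to work:

```
.version 6.1
.target sm_61
.address_size 64

// f32/f64 math on untyped .b32/.b64 registers, no %f/%fd registers involved.
.visible .entry fp_on_untyped_regs(.param .u64 fp_on_untyped_regs_param_0)
{
    .reg .b32 %r<3>;
    .reg .b64 %rd<5>;

    ld.param.u64        %rd1, [fp_on_untyped_regs_param_0];
    cvta.to.global.u64  %rd2, %rd1;

    ld.global.b32       %r1, [%rd2];
    add.rn.f32          %r2, %r1, %r1;     // f32 add on .b32 registers
    st.global.b32       [%rd2], %r2;

    ld.global.b64       %rd3, [%rd2+8];
    mul.rn.f64          %rd4, %rd3, %rd3;  // f64 mul on .b64 registers
    st.global.b64       [%rd2+8], %rd4;
    ret;
}
```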

@AlexMaclean force-pushed the dev/amaclean/upstream/untyped-rc branch from 101c2d0 to 916ac4f on May 21, 2025 at 16:41
@AlexMaclean merged commit 76c9bfe into llvm:main on May 21, 2025
11 checks passed
@Artem-B
Member

Artem-B commented Jun 3, 2025

@AlexMaclean Just an FYI.

We've run into an interesting issue in ptxas triggered by this change.

It appears that uniform use of .b32 registers for integer and FP operations may trigger some sort of regression in ptxas.

We observed that on some kernels where ptxas was able to eliminate local storage before the patch, it no longer does so.
The PTX is identical before and after this patch, modulo the number of registers and their names, but with the mixed fp32/b32 registers the generated SASS had no local memory loads/stores, while after the patch the local memory continues to be used.

Unfortunately I can't share the kernel, but I will try to reduce the PTX to something that can be used to demonstrate the issue.

@AlexMaclean
Member Author

Yikes! That is unfortunate. Is the regression significant and widespread enough that we'd need to think about reverting this change? Or could it wait for a fix to ptxas in some future release of CUDA? Alternatively, would it be possible to eliminate the local memory in LLVM via SROA?

@Artem-B
Member

Artem-B commented Jun 3, 2025

Most of the work on eliminating local storage falls on LLVM. In most of the kernels that still use local memory, ptxas usually can't do much about it either. This particular kernel happened to be the rare occasion where ptxas was previously able to remove a small 8-byte chunk of local data that LLVM kept around. AFAICT, such cases are pretty rare. So far I've seen only one TU with a few kernels, and was able to work around it at the source level (helped LLVM get rid of the alloca).

davidberard98 added a commit to davidberard98/llvm-project that referenced this pull request Jul 9, 2025
@davidberard98
Contributor

FYI, we saw a ~4x regression in a layernorm backward kernel in Triton that bisects to this PR; details here: meta-pytorch/tritonbench#264
