[NVPTX] Remove Float register classes #140487
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

These classes are redundant, as the untyped "Int" classes can be used for all float operations. This change is intended to be as minimal as possible and leaves the many potential simplifications and refactors this exposes as future work.

Patch is 1.24 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140487.diff

89 Files Affected:
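At the PTX level, the visible effect of the change (reflected throughout the test updates below) is that f32 values move from the dedicated %f registers into the untyped %r registers, and f64 values from %fd into %rd. A minimal sketch, with the instruction and register indices chosen purely for illustration:

Before (dedicated 32-bit float register class):
    .reg .b32 %f<4>;
    add.rn.f32 %f3, %f1, %f2;     // f32 math in a typed %f register
After (untyped 32-bit registers carry f32 values too):
    .reg .b32 %r<4>;
    add.rn.f32 %r3, %r1, %r2;     // same instruction, untyped %r register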
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 0e5207cf9b04c..e2e42ff771336 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -223,10 +223,6 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
Ret = (3 << 28);
} else if (RC == &NVPTX::Int64RegsRegClass) {
Ret = (4 << 28);
- } else if (RC == &NVPTX::Float32RegsRegClass) {
- Ret = (5 << 28);
- } else if (RC == &NVPTX::Float64RegsRegClass) {
- Ret = (6 << 28);
} else if (RC == &NVPTX::Int128RegsRegClass) {
Ret = (7 << 28);
} else {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 82d00ef8eccb9..9a82db31e43a0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -596,8 +596,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
- addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
- addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
+ addRegisterClass(MVT::f32, &NVPTX::Int32RegsRegClass);
+ addRegisterClass(MVT::f64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
@@ -4931,13 +4931,14 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
case 'b':
return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
case 'c':
- return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
case 'h':
return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
case 'r':
+ case 'f':
return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
case 'l':
case 'N':
+ case 'd':
return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
case 'q': {
if (STI.getSmVersion() < 70)
@@ -4945,10 +4946,6 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
"supported for sm_70 and higher!");
return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
}
- case 'f':
- return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
- case 'd':
- return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
}
}
return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index 67dc7904a91ae..f262a0fb66c25 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -44,19 +44,11 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
} else if (DestRC == &NVPTX::Int16RegsRegClass) {
Op = NVPTX::MOV16r;
} else if (DestRC == &NVPTX::Int32RegsRegClass) {
- Op = (SrcRC == &NVPTX::Int32RegsRegClass ? NVPTX::IMOV32r
- : NVPTX::BITCONVERT_32_F2I);
+ Op = NVPTX::IMOV32r;
} else if (DestRC == &NVPTX::Int64RegsRegClass) {
- Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64r
- : NVPTX::BITCONVERT_64_F2I);
+ Op = NVPTX::IMOV64r;
} else if (DestRC == &NVPTX::Int128RegsRegClass) {
Op = NVPTX::IMOV128r;
- } else if (DestRC == &NVPTX::Float32RegsRegClass) {
- Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32r
- : NVPTX::BITCONVERT_32_I2F);
- } else if (DestRC == &NVPTX::Float64RegsRegClass) {
- Op = (SrcRC == &NVPTX::Float64RegsRegClass ? NVPTX::FMOV64r
- : NVPTX::BITCONVERT_64_I2F);
} else {
llvm_unreachable("Bad register copy");
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
index 6b9797c3e6aae..eb60e1502cf90 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -25,10 +25,6 @@ using namespace llvm;
namespace llvm {
StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
- if (RC == &NVPTX::Float32RegsRegClass)
- return ".b32";
- if (RC == &NVPTX::Float64RegsRegClass)
- return ".b64";
if (RC == &NVPTX::Int128RegsRegClass)
return ".b128";
if (RC == &NVPTX::Int64RegsRegClass)
@@ -63,10 +59,6 @@ StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
}
StringRef getNVPTXRegClassStr(TargetRegisterClass const *RC) {
- if (RC == &NVPTX::Float32RegsRegClass)
- return "%f";
- if (RC == &NVPTX::Float64RegsRegClass)
- return "%fd";
if (RC == &NVPTX::Int128RegsRegClass)
return "%rq";
if (RC == &NVPTX::Int64RegsRegClass)
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
index 2011f0f7e328f..2eea9e9721cdf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -40,8 +40,6 @@ foreach i = 0...4 in {
def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit
def H#i : NVPTXReg<"%h"#i>; // 16-bit float
def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float
- def F#i : NVPTXReg<"%f"#i>; // 32-bit float
- def FL#i : NVPTXReg<"%fd"#i>; // 64-bit float
// Arguments
def ia#i : NVPTXReg<"%ia"#i>;
@@ -59,14 +57,13 @@ foreach i = 0...31 in {
//===----------------------------------------------------------------------===//
def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
-def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
+def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
(add (sequence "R%u", 0, 4),
VRFrame32, VRFrameLocal32)>;
-def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
+def Int64Regs : NVPTXRegClass<[i64, f64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
-def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
-def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
+
def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
def Int64ArgRegs : NVPTXRegClass<[i64], 64, (add (sequence "la%u", 0, 4))>;
def Float32ArgRegs : NVPTXRegClass<[f32], 32, (add (sequence "fa%u", 0, 4))>;
@@ -75,3 +72,6 @@ def Float64ArgRegs : NVPTXRegClass<[f64], 64, (add (sequence "da%u", 0, 4))>;
// Read NVPTXRegisterInfo.cpp to see how VRFrame and VRDepot are used.
def SpecialRegs : NVPTXRegClass<[i32], 32, (add VRFrame32, VRFrameLocal32, VRDepot,
(sequence "ENVREG%u", 0, 31))>;
+
+defvar Float32Regs = Int32Regs;
+defvar Float64Regs = Int64Regs;
diff --git a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
index 78b57badc06e8..1207c429524ca 100644
--- a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
+++ b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
@@ -45,36 +45,36 @@ define half @fh(ptr %p) {
; ENABLED-LABEL: fh(
; ENABLED: {
; ENABLED-NEXT: .reg .b16 %rs<10>;
-; ENABLED-NEXT: .reg .b32 %f<13>;
+; ENABLED-NEXT: .reg .b32 %r<13>;
; ENABLED-NEXT: .reg .b64 %rd<2>;
; ENABLED-EMPTY:
; ENABLED-NEXT: // %bb.0:
; ENABLED-NEXT: ld.param.b64 %rd1, [fh_param_0];
; ENABLED-NEXT: ld.v4.b16 {%rs1, %rs2, %rs3, %rs4}, [%rd1];
; ENABLED-NEXT: ld.b16 %rs5, [%rd1+8];
-; ENABLED-NEXT: cvt.f32.f16 %f1, %rs2;
-; ENABLED-NEXT: cvt.f32.f16 %f2, %rs1;
-; ENABLED-NEXT: add.rn.f32 %f3, %f2, %f1;
-; ENABLED-NEXT: cvt.rn.f16.f32 %rs6, %f3;
-; ENABLED-NEXT: cvt.f32.f16 %f4, %rs4;
-; ENABLED-NEXT: cvt.f32.f16 %f5, %rs3;
-; ENABLED-NEXT: add.rn.f32 %f6, %f5, %f4;
-; ENABLED-NEXT: cvt.rn.f16.f32 %rs7, %f6;
-; ENABLED-NEXT: cvt.f32.f16 %f7, %rs7;
-; ENABLED-NEXT: cvt.f32.f16 %f8, %rs6;
-; ENABLED-NEXT: add.rn.f32 %f9, %f8, %f7;
-; ENABLED-NEXT: cvt.rn.f16.f32 %rs8, %f9;
-; ENABLED-NEXT: cvt.f32.f16 %f10, %rs8;
-; ENABLED-NEXT: cvt.f32.f16 %f11, %rs5;
-; ENABLED-NEXT: add.rn.f32 %f12, %f10, %f11;
-; ENABLED-NEXT: cvt.rn.f16.f32 %rs9, %f12;
+; ENABLED-NEXT: cvt.f32.f16 %r1, %rs2;
+; ENABLED-NEXT: cvt.f32.f16 %r2, %rs1;
+; ENABLED-NEXT: add.rn.f32 %r3, %r2, %r1;
+; ENABLED-NEXT: cvt.rn.f16.f32 %rs6, %r3;
+; ENABLED-NEXT: cvt.f32.f16 %r4, %rs4;
+; ENABLED-NEXT: cvt.f32.f16 %r5, %rs3;
+; ENABLED-NEXT: add.rn.f32 %r6, %r5, %r4;
+; ENABLED-NEXT: cvt.rn.f16.f32 %rs7, %r6;
+; ENABLED-NEXT: cvt.f32.f16 %r7, %rs7;
+; ENABLED-NEXT: cvt.f32.f16 %r8, %rs6;
+; ENABLED-NEXT: add.rn.f32 %r9, %r8, %r7;
+; ENABLED-NEXT: cvt.rn.f16.f32 %rs8, %r9;
+; ENABLED-NEXT: cvt.f32.f16 %r10, %rs8;
+; ENABLED-NEXT: cvt.f32.f16 %r11, %rs5;
+; ENABLED-NEXT: add.rn.f32 %r12, %r10, %r11;
+; ENABLED-NEXT: cvt.rn.f16.f32 %rs9, %r12;
; ENABLED-NEXT: st.param.b16 [func_retval0], %rs9;
; ENABLED-NEXT: ret;
;
; DISABLED-LABEL: fh(
; DISABLED: {
; DISABLED-NEXT: .reg .b16 %rs<10>;
-; DISABLED-NEXT: .reg .b32 %f<13>;
+; DISABLED-NEXT: .reg .b32 %r<13>;
; DISABLED-NEXT: .reg .b64 %rd<2>;
; DISABLED-EMPTY:
; DISABLED-NEXT: // %bb.0:
@@ -84,22 +84,22 @@ define half @fh(ptr %p) {
; DISABLED-NEXT: ld.b16 %rs3, [%rd1+4];
; DISABLED-NEXT: ld.b16 %rs4, [%rd1+6];
; DISABLED-NEXT: ld.b16 %rs5, [%rd1+8];
-; DISABLED-NEXT: cvt.f32.f16 %f1, %rs2;
-; DISABLED-NEXT: cvt.f32.f16 %f2, %rs1;
-; DISABLED-NEXT: add.rn.f32 %f3, %f2, %f1;
-; DISABLED-NEXT: cvt.rn.f16.f32 %rs6, %f3;
-; DISABLED-NEXT: cvt.f32.f16 %f4, %rs4;
-; DISABLED-NEXT: cvt.f32.f16 %f5, %rs3;
-; DISABLED-NEXT: add.rn.f32 %f6, %f5, %f4;
-; DISABLED-NEXT: cvt.rn.f16.f32 %rs7, %f6;
-; DISABLED-NEXT: cvt.f32.f16 %f7, %rs7;
-; DISABLED-NEXT: cvt.f32.f16 %f8, %rs6;
-; DISABLED-NEXT: add.rn.f32 %f9, %f8, %f7;
-; DISABLED-NEXT: cvt.rn.f16.f32 %rs8, %f9;
-; DISABLED-NEXT: cvt.f32.f16 %f10, %rs8;
-; DISABLED-NEXT: cvt.f32.f16 %f11, %rs5;
-; DISABLED-NEXT: add.rn.f32 %f12, %f10, %f11;
-; DISABLED-NEXT: cvt.rn.f16.f32 %rs9, %f12;
+; DISABLED-NEXT: cvt.f32.f16 %r1, %rs2;
+; DISABLED-NEXT: cvt.f32.f16 %r2, %rs1;
+; DISABLED-NEXT: add.rn.f32 %r3, %r2, %r1;
+; DISABLED-NEXT: cvt.rn.f16.f32 %rs6, %r3;
+; DISABLED-NEXT: cvt.f32.f16 %r4, %rs4;
+; DISABLED-NEXT: cvt.f32.f16 %r5, %rs3;
+; DISABLED-NEXT: add.rn.f32 %r6, %r5, %r4;
+; DISABLED-NEXT: cvt.rn.f16.f32 %rs7, %r6;
+; DISABLED-NEXT: cvt.f32.f16 %r7, %rs7;
+; DISABLED-NEXT: cvt.f32.f16 %r8, %rs6;
+; DISABLED-NEXT: add.rn.f32 %r9, %r8, %r7;
+; DISABLED-NEXT: cvt.rn.f16.f32 %rs8, %r9;
+; DISABLED-NEXT: cvt.f32.f16 %r10, %rs8;
+; DISABLED-NEXT: cvt.f32.f16 %r11, %rs5;
+; DISABLED-NEXT: add.rn.f32 %r12, %r10, %r11;
+; DISABLED-NEXT: cvt.rn.f16.f32 %rs9, %r12;
; DISABLED-NEXT: st.param.b16 [func_retval0], %rs9;
; DISABLED-NEXT: ret;
%p.1 = getelementptr half, ptr %p, i32 1
@@ -121,37 +121,37 @@ define half @fh(ptr %p) {
define float @ff(ptr %p) {
; ENABLED-LABEL: ff(
; ENABLED: {
-; ENABLED-NEXT: .reg .b32 %f<10>;
+; ENABLED-NEXT: .reg .b32 %r<10>;
; ENABLED-NEXT: .reg .b64 %rd<2>;
; ENABLED-EMPTY:
; ENABLED-NEXT: // %bb.0:
; ENABLED-NEXT: ld.param.b64 %rd1, [ff_param_0];
-; ENABLED-NEXT: ld.v4.b32 {%f1, %f2, %f3, %f4}, [%rd1];
-; ENABLED-NEXT: ld.b32 %f5, [%rd1+16];
-; ENABLED-NEXT: add.rn.f32 %f6, %f1, %f2;
-; ENABLED-NEXT: add.rn.f32 %f7, %f3, %f4;
-; ENABLED-NEXT: add.rn.f32 %f8, %f6, %f7;
-; ENABLED-NEXT: add.rn.f32 %f9, %f8, %f5;
-; ENABLED-NEXT: st.param.b32 [func_retval0], %f9;
+; ENABLED-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; ENABLED-NEXT: ld.b32 %r5, [%rd1+16];
+; ENABLED-NEXT: add.rn.f32 %r6, %r1, %r2;
+; ENABLED-NEXT: add.rn.f32 %r7, %r3, %r4;
+; ENABLED-NEXT: add.rn.f32 %r8, %r6, %r7;
+; ENABLED-NEXT: add.rn.f32 %r9, %r8, %r5;
+; ENABLED-NEXT: st.param.b32 [func_retval0], %r9;
; ENABLED-NEXT: ret;
;
; DISABLED-LABEL: ff(
; DISABLED: {
-; DISABLED-NEXT: .reg .b32 %f<10>;
+; DISABLED-NEXT: .reg .b32 %r<10>;
; DISABLED-NEXT: .reg .b64 %rd<2>;
; DISABLED-EMPTY:
; DISABLED-NEXT: // %bb.0:
; DISABLED-NEXT: ld.param.b64 %rd1, [ff_param_0];
-; DISABLED-NEXT: ld.b32 %f1, [%rd1];
-; DISABLED-NEXT: ld.b32 %f2, [%rd1+4];
-; DISABLED-NEXT: ld.b32 %f3, [%rd1+8];
-; DISABLED-NEXT: ld.b32 %f4, [%rd1+12];
-; DISABLED-NEXT: ld.b32 %f5, [%rd1+16];
-; DISABLED-NEXT: add.rn.f32 %f6, %f1, %f2;
-; DISABLED-NEXT: add.rn.f32 %f7, %f3, %f4;
-; DISABLED-NEXT: add.rn.f32 %f8, %f6, %f7;
-; DISABLED-NEXT: add.rn.f32 %f9, %f8, %f5;
-; DISABLED-NEXT: st.param.b32 [func_retval0], %f9;
+; DISABLED-NEXT: ld.b32 %r1, [%rd1];
+; DISABLED-NEXT: ld.b32 %r2, [%rd1+4];
+; DISABLED-NEXT: ld.b32 %r3, [%rd1+8];
+; DISABLED-NEXT: ld.b32 %r4, [%rd1+12];
+; DISABLED-NEXT: ld.b32 %r5, [%rd1+16];
+; DISABLED-NEXT: add.rn.f32 %r6, %r1, %r2;
+; DISABLED-NEXT: add.rn.f32 %r7, %r3, %r4;
+; DISABLED-NEXT: add.rn.f32 %r8, %r6, %r7;
+; DISABLED-NEXT: add.rn.f32 %r9, %r8, %r5;
+; DISABLED-NEXT: st.param.b32 [func_retval0], %r9;
; DISABLED-NEXT: ret;
%p.1 = getelementptr float, ptr %p, i32 1
%p.2 = getelementptr float, ptr %p, i32 2
diff --git a/llvm/test/CodeGen/NVPTX/access-non-generic.ll b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
index a816f2e84b064..9edd4de017ee2 100644
--- a/llvm/test/CodeGen/NVPTX/access-non-generic.ll
+++ b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
@@ -23,10 +23,10 @@ define void @ld_st_shared_f32(i32 %i, float %v) {
; load cast
%1 = load float, ptr addrspacecast (ptr addrspace(3) @scalar to ptr), align 4
call void @use(float %1)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [scalar];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [scalar];
; store cast
store float %v, ptr addrspacecast (ptr addrspace(3) @scalar to ptr), align 4
-; PTX: st.shared.b32 [scalar], %f{{[0-9]+}};
+; PTX: st.shared.b32 [scalar], %r{{[0-9]+}};
; use syncthreads to disable optimizations across components
call void @llvm.nvvm.barrier0()
; PTX: bar.sync 0;
@@ -35,20 +35,20 @@ define void @ld_st_shared_f32(i32 %i, float %v) {
%2 = addrspacecast ptr addrspace(3) @scalar to ptr
%3 = load float, ptr %2, align 4
call void @use(float %3)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [scalar];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [scalar];
; cast; store
store float %v, ptr %2, align 4
-; PTX: st.shared.b32 [scalar], %f{{[0-9]+}};
+; PTX: st.shared.b32 [scalar], %r{{[0-9]+}};
call void @llvm.nvvm.barrier0()
; PTX: bar.sync 0;
; load gep cast
%4 = load float, ptr getelementptr inbounds ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i32 0, i32 5), align 4
call void @use(float %4)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [array+20];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [array+20];
; store gep cast
store float %v, ptr getelementptr inbounds ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i32 0, i32 5), align 4
-; PTX: st.shared.b32 [array+20], %f{{[0-9]+}};
+; PTX: st.shared.b32 [array+20], %r{{[0-9]+}};
call void @llvm.nvvm.barrier0()
; PTX: bar.sync 0;
@@ -56,10 +56,10 @@ define void @ld_st_shared_f32(i32 %i, float %v) {
%5 = getelementptr inbounds [10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i32 0, i32 5
%6 = load float, ptr %5, align 4
call void @use(float %6)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [array+20];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [array+20];
; gep cast; store
store float %v, ptr %5, align 4
-; PTX: st.shared.b32 [array+20], %f{{[0-9]+}};
+; PTX: st.shared.b32 [array+20], %r{{[0-9]+}};
call void @llvm.nvvm.barrier0()
; PTX: bar.sync 0;
@@ -68,10 +68,10 @@ define void @ld_st_shared_f32(i32 %i, float %v) {
%8 = getelementptr inbounds [10 x float], ptr %7, i32 0, i32 %i
%9 = load float, ptr %8, align 4
call void @use(float %9)
-; PTX: ld.shared.b32 %f{{[0-9]+}}, [%{{(r|rl|rd)[0-9]+}}];
+; PTX: ld.shared.b32 %r{{[0-9]+}}, [%{{(r|rl|rd)[0-9]+}}];
; cast; gep; store
store float %v, ptr %8, align 4
-; PTX: st.shared.b32 [%{{(r|rl|rd)[0-9]+}}], %f{{[0-9]+}};
+; PTX: st.shared.b32 [%{{(r|rl|rd)[0-9]+}}], %r{{[0-9]+}};
call void @llvm.nvvm.barrier0()
; PTX: bar.sync 0;
diff --git a/llvm/test/CodeGen/NVPTX/aggregate-return.ll b/llvm/test/CodeGen/NVPTX/aggregate-return.ll
index 72c302433f081..1c8f019922e37 100644
--- a/llvm/test/CodeGen/NVPTX/aggregate-return.ll
+++ b/llvm/test/CodeGen/NVPTX/aggregate-return.ll
@@ -10,7 +10,7 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
; CHECK-LABEL: @test_v2f32
%call = tail call <2 x float> @barv(<2 x float> %input)
; CHECK: .param .align 8 .b8 retval0[8];
-; CHECK: ld.param.v2.b32 {[[E0:%f[0-9]+]], [[E1:%f[0-9]+]]}, [retval0];
+; CHECK: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
store <2 x float> %call, ptr %output, align 8
; CHECK: st.v2.b32 [{{%rd[0-9]+}}], {[[E0]], [[E1]]}
ret void
@@ -21,10 +21,10 @@ define void @test_v3f32(<3 x float> %input, ptr %output) {
;
%call = tail call <3 x float> @barv3(<3 x float> %input)
; CHECK: .param .align 16 .b8 retval0[16];
-; CHECK-DAG: ld.param.v2.b32 {[[E0:%f[0-9]+]], [[E1:%f[0-9]+]]}, [retval0];
-; CHECK-DAG: ld.param.b32 [[E2:%f[0-9]+]], [retval0+8];
+; CHECK-DAG: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
+; CHECK-DAG: ld.param.b32 [[E2:%r[0-9]+]], [retval0+8];
; Make sure we don't load more values than than we need to.
-; CHECK-NOT: ld.param.b32 [[E3:%f[0-9]+]], [retval0+12];
+; CHECK-NOT: ld.param.b32 [[E3:%r[0-9]+]], [retval0+12];
store <3 x float> %call, ptr %output, align 8
; CHECK-DAG: st.b32 [{{%rd[0-9]}}+8],
; -- This is suboptimal. We should do st.v2.f32 instead
@@ -38,8 +38,8 @@ define void @test_a2f32([2 x float] %input, ptr %output) {
; CHECK-LABEL: @test_a2f32
%call = tail call [2 x float] @bara([2 x float] %input)
; CHECK: .param .align 4 .b8 retval0[8];
-; CHECK-DAG: ld.param.b32 [[ELEMA1:%f[0-9]+]], [retval0];
-; CHECK-DAG: ld.param.b32 [[ELEMA2:%f[0-9]+]], [retval0+4];
+; CHECK-DAG: ld.param.b32 [[ELEMA1:%r[0-9]+]], [retval0];
+; CHECK-DAG: ld.param.b32 [[ELEMA2:%r[0-9]+]], [retval0+4];
store [2 x float] %call, ptr %output, align 4
; CHECK: }
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMA1]]
@@ -52,8 +52,8 @@ define void @test_s2f32({float, float} %input, ptr %output) {
; CHECK-LABEL: @test_s2f32
%call = tail call {float, float} @bars({float, float} %input)
; CHECK: .param .align 4 .b8 retval0[8];
-; CHECK-DAG: ld.param.b32 [[ELEMS1:%f[0-9]+]], [retval0];
-; CHECK-DAG: ld.param.b32 [[ELEMS2:%f[0-9]+]], [retval0+4];
+; CHECK-DAG: ld.param.b32 [[ELEMS1:%r[0-9]+]], [retval0];
+; CHECK-DAG: ld.param.b32 [[ELEMS2:%r[0-9]+]], [retval0+4];
store {float, float} %call, ptr %output, align 4
; CHECK: }
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMS1]]
diff --git a/llvm/test/CodeGen/NVPTX/and-or-setcc.ll b/llvm/test/CodeGen/NVPTX/and-or-setcc.ll
index 53c741bd6cb2c..b7e6e8b85298a 100644
--- a/llvm/test/CodeGen/NVPTX/and-or-setcc.ll
+++ b/llvm/test/CodeGen/NVPTX/and-or-setcc.ll
@@ -8,15 +8,14 @@ define i1 @and_ord(float %a, float %b) {
; CHECK-LABEL: and_ord(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
-; CHECK-NEXT: .reg .b32 %r<2>;
-; CHECK-NEXT: .reg .b32 %f<3>;
+; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b32 %f1, [and_ord_param_0];
-; CHECK-NEXT: ld.param.b32 %f2, [and_ord_param_1];
-; CHECK-NEXT: setp.num.f32 %p1, %f1, %f2;
-; CHECK-NEXT: selp.b32 %r1, 1, 0, %p1;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ld.param.b32 %r1, [and_ord_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [and_ord_param_1];
+; CHECK-NEXT: setp.num.f32 %p1, %r1, %r2;
+; CHECK-NEXT: selp.b32 %r3, 1, 0, %p1;
+; CHECK-NEXT: st.param.b32 [func_retv...
[truncated]
LGTM.
I've checked how far back ptxas supports FP operations on .b32/.b64 registers, and it appears to work in CUDA versions as old as 9.1: https://godbolt.org/z/nbvPe57dc
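For illustration, a minimal, hypothetical probe of the kind one might feed to an older ptxas to confirm this (not the actual godbolt test case; the kernel name, ISA version, and target are chosen only to match a CUDA 9.1-era toolchain): f64 arithmetic carried out entirely in untyped .b64 registers.

    .version 6.1
    .target sm_30
    .address_size 64

    .visible .entry f64_on_b64(.param .u64 f64_on_b64_param_0)
    {
        .reg .b64 %rd<5>;

        ld.param.b64 %rd1, [f64_on_b64_param_0];
        cvta.to.global.u64 %rd2, %rd1;
        ld.global.b64 %rd3, [%rd2];
        add.rn.f64 %rd4, %rd3, %rd3;    // f64 add on an untyped .b64 register
        st.global.b64 [%rd2], %rd4;
        ret;
    }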
@AlexMaclean Just an FYI: we've run into an interesting issue in ptxas triggered by this change. It appears that uniform use of .b32 registers for integer and FP operations may trigger some sort of regression in ptxas. We observed that on some kernels where ptxas was able to eliminate local storage before the patch, it no longer does so. Unfortunately I can't share the kernel, but I will try to reduce the PTX to something that can be used to demonstrate the issue.
Yikes! That is unfortunate. Is the regression significant and widespread enough that we'd need to think about reverting this change? Or could it wait for a fix to ptxas in some future release of CUDA? Alternatively, would it be possible to eliminate the local memory in LLVM via SROA?
Most of the work on eliminating local storage falls on LLVM. In most of the kernels that still use local memory, ptxas usually can't do much about it either. This particular kernel happened to be the rare occasion where ptxas was previously able to remove a small 8-byte chunk of local data that LLVM kept around. AFAICT, such cases are pretty rare. So far I've seen only one TU with a few kernels, and I was able to work around it at the source level (helped LLVM get rid of the alloca).
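For readers unfamiliar with what this looks like in PTX, a generic, hypothetical sketch of the pattern in question (not the actual kernel, which cannot be shared): an 8-byte __local_depot slot that LLVM keeps alive, which ptxas could sometimes promote to registers and thereby eliminate the .local traffic.

    .version 6.1
    .target sm_70
    .address_size 64

    .visible .entry depot_example(.param .u64 depot_example_param_0)
    {
        // 8-byte local slot left behind by LLVM (e.g. an un-SROA'd alloca)
        .local .align 8 .b8 __local_depot0[8];
        .reg .b64 %SPL;
        .reg .b32 %r<6>;
        .reg .b64 %rd<3>;

        mov.u64 %SPL, __local_depot0;
        ld.param.b64 %rd1, [depot_example_param_0];
        cvta.to.global.u64 %rd2, %rd1;
        ld.global.b32 %r1, [%rd2];
        ld.global.b32 %r2, [%rd2+4];
        st.local.b32 [%SPL], %r1;      // the small chunk of local data...
        st.local.b32 [%SPL+4], %r2;    // ...that LLVM kept around
        ld.local.b32 %r3, [%SPL];
        ld.local.b32 %r4, [%SPL+4];
        add.rn.f32 %r5, %r3, %r4;      // f32 math on untyped .b32 registers
        st.global.b32 [%rd2], %r5;
        ret;
    }

Whether ptxas manages to forward the stores to the loads and drop __local_depot0 entirely is the behavior that appears to have regressed once the registers became uniformly untyped.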
This reverts commit 76c9bfe.
FYI, we saw a ~4x regression in a layernorm backward kernel in Triton that bisects to this PR; details here: meta-pytorch/tritonbench#264