Skip to content

Conversation

@AlexMaclean
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-debuginfo

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 368.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145255.diff

7 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (+5-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+17-17)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp (+5-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+443-478)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+1423-2096)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp (+10-10)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+5-13)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 9af6fb2cb198e..38912a7f09e30 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -215,15 +215,15 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
     // Encode the register class in the upper 4 bits
     // Must be kept in sync with NVPTXInstPrinter::printRegName
     unsigned Ret = 0;
-    if (RC == &NVPTX::Int1RegsRegClass) {
+    if (RC == &NVPTX::B1RegClass) {
       Ret = (1 << 28);
-    } else if (RC == &NVPTX::Int16RegsRegClass) {
+    } else if (RC == &NVPTX::B16RegClass) {
       Ret = (2 << 28);
-    } else if (RC == &NVPTX::Int32RegsRegClass) {
+    } else if (RC == &NVPTX::B32RegClass) {
       Ret = (3 << 28);
-    } else if (RC == &NVPTX::Int64RegsRegClass) {
+    } else if (RC == &NVPTX::B64RegClass) {
       Ret = (4 << 28);
-    } else if (RC == &NVPTX::Int128RegsRegClass) {
+    } else if (RC == &NVPTX::B128RegClass) {
       Ret = (7 << 28);
     } else {
       report_fatal_error("Bad register class");
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 492f4ab76fdbb..676654d6d33e7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -589,18 +589,18 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(Op, VT, IsOpSupported ? Action : NoI16x2Action);
   };
 
-  addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
-  addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
-  addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
-  addRegisterClass(MVT::f32, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::f64, &NVPTX::Int64RegsRegClass);
-  addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
-  addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
-  addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
-  addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
+  addRegisterClass(MVT::i1, &NVPTX::B1RegClass);
+  addRegisterClass(MVT::i16, &NVPTX::B16RegClass);
+  addRegisterClass(MVT::v2i16, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::v4i8, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::i32, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::i64, &NVPTX::B64RegClass);
+  addRegisterClass(MVT::f32, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::f64, &NVPTX::B64RegClass);
+  addRegisterClass(MVT::f16, &NVPTX::B16RegClass);
+  addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
+  addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
+  addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -4866,22 +4866,22 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
   if (Constraint.size() == 1) {
     switch (Constraint[0]) {
     case 'b':
-      return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B1RegClass);
     case 'c':
     case 'h':
-      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B16RegClass);
     case 'r':
     case 'f':
-      return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B32RegClass);
     case 'l':
     case 'N':
     case 'd':
-      return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B64RegClass);
     case 'q': {
       if (STI.getSmVersion() < 70)
         report_fatal_error("Inline asm with 128 bit operands is only "
                            "supported for sm_70 and higher!");
-      return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
+      return std::make_pair(0U, &NVPTX::B128RegClass);
     }
     }
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index f262a0fb66c25..bf84d1dca4ed5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -39,15 +39,15 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     report_fatal_error("Copy one register into another with a different width");
 
   unsigned Op;
-  if (DestRC == &NVPTX::Int1RegsRegClass) {
+  if (DestRC == &NVPTX::B1RegClass) {
     Op = NVPTX::IMOV1r;
-  } else if (DestRC == &NVPTX::Int16RegsRegClass) {
+  } else if (DestRC == &NVPTX::B16RegClass) {
     Op = NVPTX::MOV16r;
-  } else if (DestRC == &NVPTX::Int32RegsRegClass) {
+  } else if (DestRC == &NVPTX::B32RegClass) {
     Op = NVPTX::IMOV32r;
-  } else if (DestRC == &NVPTX::Int64RegsRegClass) {
+  } else if (DestRC == &NVPTX::B64RegClass) {
     Op = NVPTX::IMOV64r;
-  } else if (DestRC == &NVPTX::Int128RegsRegClass) {
+  } else if (DestRC == &NVPTX::B128RegClass) {
     Op = NVPTX::IMOV128r;
   } else {
     llvm_unreachable("Bad register copy");
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index bbe99dec5c445..5979054764647 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -170,29 +170,6 @@ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
 def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
 def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">;
 
-// Helper class to aid conversion between ValueType and a matching RegisterClass.
-
-class ValueToRegClass<ValueType T> {
-   string name = !cast<string>(T);
-   NVPTXRegClass ret = !cond(
-     !eq(name, "i1"): Int1Regs,
-     !eq(name, "i16"): Int16Regs,
-     !eq(name, "v2i16"): Int32Regs,
-     !eq(name, "i32"): Int32Regs,
-     !eq(name, "i64"): Int64Regs,
-     !eq(name, "f16"): Int16Regs,
-     !eq(name, "v2f16"): Int32Regs,
-     !eq(name, "bf16"): Int16Regs,
-     !eq(name, "v2bf16"): Int32Regs,
-     !eq(name, "f32"): Float32Regs,
-     !eq(name, "f64"): Float64Regs,
-     !eq(name, "ai32"): Int32ArgRegs,
-     !eq(name, "ai64"): Int64ArgRegs,
-     !eq(name, "af32"): Float32ArgRegs,
-     !eq(name, "if64"): Float64ArgRegs,
-    );
-}
-
 
 //===----------------------------------------------------------------------===//
 // Some Common Instruction Class Templates
@@ -219,18 +196,18 @@ class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
   int Size = ty.Size;
 }
 
-def I1RT     : RegTyInfo<i1,  Int1Regs,  i1imm,  imm>;
-def I16RT    : RegTyInfo<i16, Int16Regs, i16imm, imm>;
-def I32RT    : RegTyInfo<i32, Int32Regs, i32imm, imm>;
-def I64RT    : RegTyInfo<i64, Int64Regs, i64imm, imm>;
+def I1RT     : RegTyInfo<i1,  B1,  i1imm,  imm>;
+def I16RT    : RegTyInfo<i16, B16, i16imm, imm>;
+def I32RT    : RegTyInfo<i32, B32, i32imm, imm>;
+def I64RT    : RegTyInfo<i64, B64, i64imm, imm>;
 
-def F32RT    : RegTyInfo<f32, Float32Regs, f32imm, fpimm>;
-def F64RT    : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
-def F16RT    : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
-def BF16RT   : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
+def F32RT    : RegTyInfo<f32, B32, f32imm, fpimm>;
+def F64RT    : RegTyInfo<f64, B64, f64imm, fpimm>;
+def F16RT    : RegTyInfo<f16, B16, f16imm, fpimm, supports_imm = 0>;
+def BF16RT   : RegTyInfo<bf16, B16, bf16imm, fpimm, supports_imm = 0>;
 
-def F16X2RT  : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
-def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
+def F16X2RT  : RegTyInfo<v2f16, B32, ?, ?, supports_imm = 0>;
+def BF16X2RT : RegTyInfo<v2bf16, B32, ?, ?, supports_imm = 0>;
 
 
 // This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -238,18 +215,18 @@ def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
 // construction of the asm string based on the provided dag arguments.
 // For example, the following asm-strings would be computed:
 //
-//   * BasicFlagsNVPTXInst<(outs Int32Regs:$dst),
-//                         (ins Int32Regs:$a, Int32Regs:$b), (ins),
+//   * BasicFlagsNVPTXInst<(outs B32:$dst),
+//                         (ins B32:$a, B32:$b), (ins),
 //                         "add.s32">;
 //         ---> "add.s32 \t$dst, $a, $b;"
 //
-//   * BasicFlagsNVPTXInst<(outs Int32Regs:$d),
-//                         (ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
+//   * BasicFlagsNVPTXInst<(outs B32:$d),
+//                         (ins B32:$a, B32:$b, Hexu32imm:$c),
 //                         (ins PrmtMode:$mode),
 //                         "prmt.b32${mode}">;
 //         ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
 //
-//   * BasicFlagsNVPTXInst<(outs Int64Regs:$state),
+//   * BasicFlagsNVPTXInst<(outs B64:$state),
 //                         (ins ADDR:$addr),
 //                         "mbarrier.arrive.b64">;
 //         ---> "mbarrier.arrive.b64 \t$state, [$addr];"
@@ -312,7 +289,7 @@ multiclass I3<string op_str, SDPatternOperator op_node, bit commutative> {
 }
 
 class I16x2<string OpcStr, SDNode OpNode> :
-  BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
+  BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
               OpcStr # "16x2",
               [(set v2i16:$dst, (OpNode v2i16:$a, v2i16:$b))]>,
               Requires<[hasPTX<80>, hasSM<90>]>;
@@ -334,73 +311,73 @@ multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
 multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
   if !not(NaN) then {
    def f64rr :
-     BasicNVPTXInst<(outs Float64Regs:$dst),
-               (ins Float64Regs:$a, Float64Regs:$b),
+     BasicNVPTXInst<(outs B64:$dst),
+               (ins B64:$a, B64:$b),
                OpcStr # ".f64",
                [(set f64:$dst, (OpNode f64:$a, f64:$b))]>;
    def f64ri :
-     BasicNVPTXInst<(outs Float64Regs:$dst),
-               (ins Float64Regs:$a, f64imm:$b),
+     BasicNVPTXInst<(outs B64:$dst),
+               (ins B64:$a, f64imm:$b),
                OpcStr # ".f64",
                [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>;
   }
    def f32rr_ftz :
-     BasicNVPTXInst<(outs Float32Regs:$dst),
-               (ins Float32Regs:$a, Float32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".ftz.f32",
                [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
                Requires<[doF32FTZ]>;
    def f32ri_ftz :
-     BasicNVPTXInst<(outs Float32Regs:$dst),
-               (ins Float32Regs:$a, f32imm:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, f32imm:$b),
                OpcStr # ".ftz.f32",
                [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
                Requires<[doF32FTZ]>;
    def f32rr :
-     BasicNVPTXInst<(outs Float32Regs:$dst),
-               (ins Float32Regs:$a, Float32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".f32",
                [(set f32:$dst, (OpNode f32:$a, f32:$b))]>;
    def f32ri :
-     BasicNVPTXInst<(outs Float32Regs:$dst),
-               (ins Float32Regs:$a, f32imm:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, f32imm:$b),
                OpcStr # ".f32",
                [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>;
 
    def f16rr_ftz :
-     BasicNVPTXInst<(outs Int16Regs:$dst),
-               (ins Int16Regs:$a, Int16Regs:$b),
+     BasicNVPTXInst<(outs B16:$dst),
+               (ins B16:$a, B16:$b),
                OpcStr # ".ftz.f16",
                [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, doF32FTZ]>;
    def f16rr :
-     BasicNVPTXInst<(outs Int16Regs:$dst),
-               (ins Int16Regs:$a, Int16Regs:$b),
+     BasicNVPTXInst<(outs B16:$dst),
+               (ins B16:$a, B16:$b),
                OpcStr # ".f16",
                [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
 
    def f16x2rr_ftz :
-     BasicNVPTXInst<(outs Int32Regs:$dst),
-               (ins Int32Regs:$a, Int32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".ftz.f16x2",
                [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>, doF32FTZ]>;
    def f16x2rr :
-     BasicNVPTXInst<(outs Int32Regs:$dst),
-               (ins Int32Regs:$a, Int32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".f16x2",
                [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
    def bf16rr :
-     BasicNVPTXInst<(outs Int16Regs:$dst),
-               (ins Int16Regs:$a, Int16Regs:$b),
+     BasicNVPTXInst<(outs B16:$dst),
+               (ins B16:$a, B16:$b),
                OpcStr # ".bf16",
                [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
    def bf16x2rr :
-     BasicNVPTXInst<(outs Int32Regs:$dst),
-               (ins Int32Regs:$a, Int32Regs:$b),
+     BasicNVPTXInst<(outs B32:$dst),
+               (ins B32:$a, B32:$b),
                OpcStr # ".bf16x2",
                [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
@@ -417,73 +394,73 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
 // just like the non ".rn" op, but prevents ptxas from creating FMAs.
 multiclass F3<string op_str, SDPatternOperator op_pat> {
   def f64rr :
-    BasicNVPTXInst<(outs Float64Regs:$dst),
-              (ins Float64Regs:$a, Float64Regs:$b),
+    BasicNVPTXInst<(outs B64:$dst),
+              (ins B64:$a, B64:$b),
               op_str # ".f64",
               [(set f64:$dst, (op_pat f64:$a, f64:$b))]>;
   def f64ri :
-    BasicNVPTXInst<(outs Float64Regs:$dst),
-              (ins Float64Regs:$a, f64imm:$b),
+    BasicNVPTXInst<(outs B64:$dst),
+              (ins B64:$a, f64imm:$b),
               op_str # ".f64",
               [(set f64:$dst, (op_pat f64:$a, fpimm:$b))]>;
   def f32rr_ftz :
-    BasicNVPTXInst<(outs Float32Regs:$dst),
-              (ins Float32Regs:$a, Float32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".ftz.f32",
               [(set f32:$dst, (op_pat f32:$a, f32:$b))]>,
               Requires<[doF32FTZ]>;
   def f32ri_ftz :
-    BasicNVPTXInst<(outs Float32Regs:$dst),
-              (ins Float32Regs:$a, f32imm:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, f32imm:$b),
               op_str # ".ftz.f32",
               [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>,
               Requires<[doF32FTZ]>;
   def f32rr :
-    BasicNVPTXInst<(outs Float32Regs:$dst),
-              (ins Float32Regs:$a, Float32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".f32",
               [(set f32:$dst, (op_pat f32:$a, f32:$b))]>;
   def f32ri :
-    BasicNVPTXInst<(outs Float32Regs:$dst),
-              (ins Float32Regs:$a, f32imm:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, f32imm:$b),
               op_str # ".f32",
               [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
 
   def f16rr_ftz :
-    BasicNVPTXInst<(outs Int16Regs:$dst),
-              (ins Int16Regs:$a, Int16Regs:$b),
+    BasicNVPTXInst<(outs B16:$dst),
+              (ins B16:$a, B16:$b),
               op_str # ".ftz.f16",
               [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
               Requires<[useFP16Math, doF32FTZ]>;
   def f16rr :
-    BasicNVPTXInst<(outs Int16Regs:$dst),
-              (ins Int16Regs:$a, Int16Regs:$b),
+    BasicNVPTXInst<(outs B16:$dst),
+              (ins B16:$a, B16:$b),
               op_str # ".f16",
               [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
               Requires<[useFP16Math]>;
 
   def f16x2rr_ftz :
-    BasicNVPTXInst<(outs Int32Regs:$dst),
-              (ins Int32Regs:$a, Int32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".ftz.f16x2",
               [(set v2f16:$dst, (op_pat v2f16:$a, v2f16:$b))]>,
               Requires<[useFP16Math, doF32FTZ]>;
   def f16x2rr :
-    BasicNVPTXInst<(outs Int32Regs:$dst),
-              (ins Int32Regs:$a, Int32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".f16x2",
               [(set v2f16:$dst, (op_pat v2f16:$a, v2f16:$b))]>,
               Requires<[useFP16Math]>;
   def bf16rr :
-    BasicNVPTXInst<(outs Int16Regs:$dst),
-              (ins Int16Regs:$a, Int16Regs:$b),
+    BasicNVPTXInst<(outs B16:$dst),
+              (ins B16:$a, B16:$b),
               op_str # ".bf16",
               [(set bf16:$dst, (op_pat bf16:$a, bf16:$b))]>,
               Requires<[hasBF16Math]>;
 
   def bf16x2rr :
-    BasicNVPTXInst<(outs Int32Regs:$dst),
-              (ins Int32Regs:$a, Int32Regs:$b),
+    BasicNVPTXInst<(outs B32:$dst),
+              (ins B32:$a, B32:$b),
               op_str # ".bf16x2",
               [(set v2bf16:$dst, (op_pat v2bf16:$a, v2bf16:$b))]>,
               Requires<[hasBF16Math]>;
@@ -504,40 +481,40 @@ multiclass F3_fma_component<string op_str, SDNode op_node> {
 // instructions: <OpcStr>.f64, <OpcStr>.f32, and <OpcStr>.ftz.f32 (flush
 // subnormal inputs and results to zero).
 multiclass F2<string OpcStr, SDNode OpNode> {
-   def f64 :     BasicNVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$a),
+   def f64 :     BasicNVPTXInst<(outs B64:$dst), (ins B64:$a),
                            OpcStr # ".f64",
                            [(set f64:$dst, (OpNode f64:$a))]>;
-   def f32_ftz : BasicNVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
+   def f32_ftz : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".ftz.f32",
                            [(set f32:$dst, (OpNode f32:$a))]>,
                            Requires<[doF32FTZ]>;
-   def f32 :     BasicNVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
+   def f32 :     BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".f32",
                            [(set f32:$dst, (OpNode f32:$a))]>;
 }
 
 multiclass F2_Support_Half<string OpcStr, SDNode OpNode> {
-   def bf16 :      BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+   def bf16 :      BasicNVPTXInst<(outs B16:$dst), (ins B16:$a),
                            OpcStr # ".bf16",
                            [(set bf16:$dst, (OpNode bf16:$a))]>,
                            Requires<[hasSM<80>, hasPTX<70>]>;
-   def bf16x2 :    BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+   def bf16x2 :    BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".bf16x2",
                            [(set v2bf16:$dst, (OpNode v2bf16:$a))]>,
                            Requires<[hasSM<80>, hasPTX<70>]>;
-   def f16_ftz :   BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+   def f16_ftz :   BasicNVPTXInst<(outs B16:$dst), (ins B16:$a),
                            OpcStr # ".ftz.f16",
                            [(set f16:$dst, (OpNode f16:$a))]>,
                            Requires<[hasSM<53>, hasPTX<65>, doF32FTZ]>;
-   def f16x2_ftz : BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+   def f16x2_ftz : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".ftz.f16x2",
                            [(set v2f16:$dst, (OpNode v2f16:$a))]>,
                            Requires<[hasSM<53>, hasPTX<65>, doF32FTZ]>;
-   def f16 :       BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+   def f16 :       BasicNVPTXInst<(outs B16:$dst), (ins B16:$a),
                            OpcStr # ".f16",
                            [(set f16:$dst, (OpNode f16:$a))]>,
                            Requires<[hasSM<53>, hasPTX<65>]>;
-   def f16x2 :     BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+   def f16x2 :     BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
                            OpcStr # ".f16x2",
      ...
[truncated]

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/rename-regs branch from 1d6a90d to b2e61a1 Compare June 23, 2025 16:18
%5:int32regs = ProxyRegB32 killed %1
%6:int32regs = ProxyRegB32 killed %2
%7:int32regs = ProxyRegB32 killed %3
%4:b32 = ProxyRegB32 killed %0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this patch, but do we still need those proxyReg*. With the register classes boiled down to their exact PTX counterparts, we should not need them any more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Artem-B Can you explain further what the history/original intent was behind these ProxyRegs? I see they were added here https://reviews.llvm.org/D34708. But I don't fully follow why they were needed in the first place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original purpose was to enable library calls on the GPU:
https://reviews.llvm.org/D34708
49fac56

I've completely forgot that. Recently we've been sort of abusing those proxy registers to eliminate redundant register moves. e.g.

foreach ta = [v2f16, v2bf16, v2i16, v4i8, i32] in {
def: Pat<(ta (bitconvert (i32 UInt32Const:$a))),
(IMOVB32ri UInt32Const:$a)>;
foreach tb = [v2f16, v2bf16, v2i16, v4i8, i32] in {
if !ne(ta, tb) then {
def: Pat<(ta (bitconvert (tb Int32Regs:$a))),
(ProxyRegI32 Int32Regs:$a)>;

Your changes have largely rendered this use case obsolete.

However, I believe the original use case is still valid. So, to answer my own question -- yes, we still probably need those proxy registers.

@AlexMaclean AlexMaclean merged commit 7ce76e1 into llvm:main Jun 23, 2025
8 checks passed
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants