From 8730bc39da1ea76306fd1c5f5c202bb09a1cf98f Mon Sep 17 00:00:00 2001 From: Alex Maclean Date: Wed, 19 Feb 2025 21:21:06 +0000 Subject: [PATCH] [NVPTX] Use appropriate operands in ReplaceImageHandles (NFC) --- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 121 +++--------------- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h | 6 +- .../Target/NVPTX/NVPTXMachineFunctionInfo.h | 6 - .../Target/NVPTX/NVPTXReplaceImageHandles.cpp | 51 +++----- 4 files changed, 38 insertions(+), 146 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index c8e29c1da6ec4..6e5dd6b15900c 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -149,67 +149,6 @@ void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) { EmitToStreamer(*OutStreamer, Inst); } -// Handle symbol backtracking for targets that do not support image handles -bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI, - unsigned OpNo, MCOperand &MCOp) { - const MachineOperand &MO = MI->getOperand(OpNo); - const MCInstrDesc &MCID = MI->getDesc(); - - if (MCID.TSFlags & NVPTXII::IsTexFlag) { - // This is a texture fetch, so operand 4 is a texref and operand 5 is - // a samplerref - if (OpNo == 4 && MO.isImm()) { - lowerImageHandleSymbol(MO.getImm(), MCOp); - return true; - } - if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) { - lowerImageHandleSymbol(MO.getImm(), MCOp); - return true; - } - - return false; - } else if (MCID.TSFlags & NVPTXII::IsSuldMask) { - unsigned VecSize = - 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1); - - // For a surface load of vector size N, the Nth operand will be the surfref - if (OpNo == VecSize && MO.isImm()) { - lowerImageHandleSymbol(MO.getImm(), MCOp); - return true; - } - - return false; - } else if (MCID.TSFlags & NVPTXII::IsSustFlag) { - // This is a surface store, so operand 0 is a surfref - if (OpNo == 0 && MO.isImm()) { - lowerImageHandleSymbol(MO.getImm(), MCOp); - return true; - } - - return false; - } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) { - // This is a query, so operand 1 is a surfref/texref - if (OpNo == 1 && MO.isImm()) { - lowerImageHandleSymbol(MO.getImm(), MCOp); - return true; - } - - return false; - } - - return false; -} - -void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) { - // Ewwww - TargetMachine &TM = const_cast(MF->getTarget()); - NVPTXTargetMachine &nvTM = static_cast(TM); - const NVPTXMachineFunctionInfo *MFI = MF->getInfo(); - StringRef Sym = MFI->getImageHandleSymbol(Index); - StringRef SymName = nvTM.getStrPool().save(Sym); - MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName)); -} - void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) { OutMI.setOpcode(MI->getOpcode()); // Special: Do not mangle symbol operand of CALL_PROTOTYPE @@ -220,67 +159,49 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) { return; } - for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) { - const MachineOperand &MO = MI->getOperand(i); - - MCOperand MCOp; - if (lowerImageHandleOperand(MI, i, MCOp)) { - OutMI.addOperand(MCOp); - continue; - } - - if (lowerOperand(MO, MCOp)) - OutMI.addOperand(MCOp); - } + for (const auto MO : MI->operands()) + OutMI.addOperand(lowerOperand(MO)); } -bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO, - MCOperand &MCOp) { +MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) { switch (MO.getType()) { - default: llvm_unreachable("unknown operand type"); + default: + llvm_unreachable("unknown operand type"); case MachineOperand::MO_Register: - MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg())); - break; + return MCOperand::createReg(encodeVirtualRegister(MO.getReg())); case MachineOperand::MO_Immediate: - MCOp = MCOperand::createImm(MO.getImm()); - break; + return MCOperand::createImm(MO.getImm()); case MachineOperand::MO_MachineBasicBlock: - MCOp = MCOperand::createExpr(MCSymbolRefExpr::create( - MO.getMBB()->getSymbol(), OutContext)); - break; + return MCOperand::createExpr( + MCSymbolRefExpr::create(MO.getMBB()->getSymbol(), OutContext)); case MachineOperand::MO_ExternalSymbol: - MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName())); - break; + return GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName())); case MachineOperand::MO_GlobalAddress: - MCOp = GetSymbolRef(getSymbol(MO.getGlobal())); - break; + return GetSymbolRef(getSymbol(MO.getGlobal())); case MachineOperand::MO_FPImmediate: { const ConstantFP *Cnt = MO.getFPImm(); const APFloat &Val = Cnt->getValueAPF(); switch (Cnt->getType()->getTypeID()) { - default: report_fatal_error("Unsupported FP type"); break; - case Type::HalfTyID: - MCOp = MCOperand::createExpr( - NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); + default: + report_fatal_error("Unsupported FP type"); break; + case Type::HalfTyID: + return MCOperand::createExpr( + NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); case Type::BFloatTyID: - MCOp = MCOperand::createExpr( + return MCOperand::createExpr( NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext)); - break; case Type::FloatTyID: - MCOp = MCOperand::createExpr( - NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); - break; + return MCOperand::createExpr( + NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); case Type::DoubleTyID: - MCOp = MCOperand::createExpr( - NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext)); - break; + return MCOperand::createExpr( + NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext)); } break; } } - return true; } unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) { diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h index f7c3fda332eff..74daaa2fb7134 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h @@ -163,7 +163,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { void emitInstruction(const MachineInstr *) override; void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI); - bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp); + MCOperand lowerOperand(const MachineOperand &MO); MCOperand GetSymbolRef(const MCSymbol *Symbol); unsigned encodeVirtualRegister(unsigned Reg); @@ -226,10 +226,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { void emitDeclarationWithName(const Function *, MCSymbol *, raw_ostream &O); void emitDemotedVars(const Function *, raw_ostream &); - bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo, - MCOperand &MCOp); - void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp); - bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const; // Used to control the need to emit .generic() in the initializer of diff --git a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h index 6670cb296f216..d9beab7ec42e1 100644 --- a/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h +++ b/llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h @@ -47,12 +47,6 @@ class NVPTXMachineFunctionInfo : public MachineFunctionInfo { return ImageHandleList.size()-1; } - /// Returns the symbol name at the given index. - StringRef getImageHandleSymbol(unsigned Idx) const { - assert(ImageHandleList.size() > Idx && "Bad index"); - return ImageHandleList[Idx]; - } - /// Check if the symbol has a mapping. Having a mapping means the handle is /// replaced with a reference bool checkImageHandleSymbol(StringRef Symbol) const { diff --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp index a3e3978cbbfe2..4d0694faa0c9a 100644 --- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp @@ -20,7 +20,6 @@ #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineRegisterInfo.h" -#include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -41,10 +40,8 @@ class NVPTXReplaceImageHandles : public MachineFunctionPass { private: bool processInstr(MachineInstr &MI); bool replaceImageHandle(MachineOperand &Op, MachineFunction &MF); - bool findIndexForHandle(MachineOperand &Op, MachineFunction &MF, - unsigned &Idx); }; -} +} // namespace char NVPTXReplaceImageHandles::ID = 0; @@ -1756,9 +1753,11 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) { } return true; - } else if (MCID.TSFlags & NVPTXII::IsSuldMask) { + } + if (MCID.TSFlags & NVPTXII::IsSuldMask) { unsigned VecSize = - 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1); + 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - + 1); // For a surface load of vector size N, the Nth operand will be the surfref MachineOperand &SurfHandle = MI.getOperand(VecSize); @@ -1767,7 +1766,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) { MI.setDesc(TII->get(suldRegisterToIndexOpcode(MI.getOpcode()))); return true; - } else if (MCID.TSFlags & NVPTXII::IsSustFlag) { + } + if (MCID.TSFlags & NVPTXII::IsSustFlag) { // This is a surface store, so operand 0 is a surfref MachineOperand &SurfHandle = MI.getOperand(0); @@ -1775,7 +1775,8 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) { MI.setDesc(TII->get(sustRegisterToIndexOpcode(MI.getOpcode()))); return true; - } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) { + } + if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) { // This is a query, so operand 1 is a surfref/texref MachineOperand &Handle = MI.getOperand(1); @@ -1790,16 +1791,6 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) { bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op, MachineFunction &MF) { - unsigned Idx; - if (findIndexForHandle(Op, MF, Idx)) { - Op.ChangeToImmediate(Idx); - return true; - } - return false; -} - -bool NVPTXReplaceImageHandles:: -findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) { const MachineRegisterInfo &MRI = MF.getRegInfo(); NVPTXMachineFunctionInfo *MFI = MF.getInfo(); @@ -1812,25 +1803,16 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) { case NVPTX::LD_i64_avar: { // The handle is a parameter value being loaded, replace with the // parameter symbol - const NVPTXTargetMachine &TM = - static_cast(MF.getTarget()); - if (TM.getDrvInterface() == NVPTX::CUDA) { + const auto &TM = static_cast(MF.getTarget()); + if (TM.getDrvInterface() == NVPTX::CUDA) // For CUDA, we preserve the param loads coming from function arguments return false; - } assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!"); StringRef Sym = TexHandleDef.getOperand(7).getSymbolName(); - std::string ParamBaseName = std::string(MF.getName()); - ParamBaseName += "_param_"; - assert(Sym.starts_with(ParamBaseName) && "Invalid symbol reference"); - unsigned Param = atoi(Sym.data()+ParamBaseName.size()); - std::string NewSym; - raw_string_ostream NewSymStr(NewSym); - NewSymStr << MF.getName() << "_param_" << Param; - InstrsToRemove.insert(&TexHandleDef); - Idx = MFI->getImageHandleSymbolIndex(NewSymStr.str()); + Op.ChangeToES(Sym.data()); + MFI->getImageHandleSymbolIndex(Sym); return true; } case NVPTX::texsurf_handles: { @@ -1839,15 +1821,14 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) { const GlobalValue *GV = TexHandleDef.getOperand(1).getGlobal(); assert(GV->hasName() && "Global sampler must be named!"); InstrsToRemove.insert(&TexHandleDef); - Idx = MFI->getImageHandleSymbolIndex(GV->getName()); + Op.ChangeToGA(GV, 0); return true; } case NVPTX::nvvm_move_i64: case TargetOpcode::COPY: { - bool Res = findIndexForHandle(TexHandleDef.getOperand(1), MF, Idx); - if (Res) { + bool Res = replaceImageHandle(TexHandleDef.getOperand(1), MF); + if (Res) InstrsToRemove.insert(&TexHandleDef); - } return Res; } default: