@@ -952,17 +952,22 @@ getAssignFnsForCC(CallingConv::ID CC, const SITargetLowering &TLI) {
952952}
953953
954954static unsigned getCallOpcode (const MachineFunction &CallerF, bool IsIndirect,
955- bool IsTailCall, bool isWave32,
956- CallingConv::ID CC) {
955+ bool IsTailCall, bool IsWave32,
956+ CallingConv::ID CC,
957+ bool IsDynamicVGPRChainCall = false ) {
957958 // For calls to amdgpu_cs_chain functions, the address is known to be uniform.
958959 assert ((AMDGPU::isChainCC (CC) || !IsIndirect || !IsTailCall) &&
959960 " Indirect calls can't be tail calls, "
960961 " because the address can be divergent" );
961962 if (!IsTailCall)
962963 return AMDGPU::G_SI_CALL;
963964
964- if (AMDGPU::isChainCC (CC))
965- return isWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
965+ if (AMDGPU::isChainCC (CC)) {
966+ if (IsDynamicVGPRChainCall)
967+ return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32_DVGPR
968+ : AMDGPU::SI_CS_CHAIN_TC_W64_DVGPR;
969+ return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
970+ }
966971
967972 return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
968973 AMDGPU::SI_TCRETURN;
@@ -971,7 +976,8 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
971976// Add operands to call instruction to track the callee.
972977static bool addCallTargetOperands (MachineInstrBuilder &CallInst,
973978 MachineIRBuilder &MIRBuilder,
974- AMDGPUCallLowering::CallLoweringInfo &Info) {
979+ AMDGPUCallLowering::CallLoweringInfo &Info,
980+ bool IsDynamicVGPRChainCall = false ) {
975981 if (Info.Callee .isReg ()) {
976982 CallInst.addReg (Info.Callee .getReg ());
977983 CallInst.addImm (0 );
@@ -982,7 +988,12 @@ static bool addCallTargetOperands(MachineInstrBuilder &CallInst,
982988 auto Ptr = MIRBuilder.buildGlobalValue (
983989 LLT::pointer (GV->getAddressSpace (), 64 ), GV);
984990 CallInst.addReg (Ptr.getReg (0 ));
985- CallInst.add (Info.Callee );
991+
992+ if (IsDynamicVGPRChainCall) {
993+ // DynamicVGPR chain calls are always indirect.
994+ CallInst.addImm (0 );
995+ } else
996+ CallInst.add (Info.Callee );
986997 } else
987998 return false ;
988999
@@ -1176,6 +1187,18 @@ void AMDGPUCallLowering::handleImplicitCallArguments(
11761187 }
11771188}
11781189
namespace {
// Chain calls have special arguments that we need to handle. These have the
// same index as they do in the llvm.amdgcn.cs.chain intrinsic: operand 1 is
// the EXEC mask, operand 4 the flags word, and operands 5-7 the extra
// dynamic-VGPR arguments (only present when flags bit 0 is set).
enum ChainCallArgIdx {
  Exec = 1,
  Flags = 4,
  NumVGPRs = 5,
  FallbackExec = 6,
  FallbackCallee = 7,
};
} // anonymous namespace
1201+
11791202bool AMDGPUCallLowering::lowerTailCall (
11801203 MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info,
11811204 SmallVectorImpl<ArgInfo> &OutArgs) const {
@@ -1184,6 +1207,8 @@ bool AMDGPUCallLowering::lowerTailCall(
11841207 SIMachineFunctionInfo *FuncInfo = MF.getInfo <SIMachineFunctionInfo>();
11851208 const Function &F = MF.getFunction ();
11861209 MachineRegisterInfo &MRI = MF.getRegInfo ();
1210+ const SIInstrInfo *TII = ST.getInstrInfo ();
1211+ const SIRegisterInfo *TRI = ST.getRegisterInfo ();
11871212 const SITargetLowering &TLI = *getTLI<SITargetLowering>();
11881213
11891214 // True when we're tail calling, but without -tailcallopt.
@@ -1199,34 +1224,79 @@ bool AMDGPUCallLowering::lowerTailCall(
11991224 if (!IsSibCall)
12001225 CallSeqStart = MIRBuilder.buildInstr (AMDGPU::ADJCALLSTACKUP);
12011226
1202- unsigned Opc =
1203- getCallOpcode (MF, Info.Callee .isReg (), true , ST.isWave32 (), CalleeCC);
1227+ bool IsChainCall = AMDGPU::isChainCC (Info.CallConv );
1228+ bool IsDynamicVGPRChainCall = false ;
1229+
1230+ if (IsChainCall) {
1231+ ArgInfo FlagsArg = Info.OrigArgs [ChainCallArgIdx::Flags];
1232+ const APInt &FlagsValue = cast<ConstantInt>(FlagsArg.OrigValue )->getValue ();
1233+ if (FlagsValue.isZero ()) {
1234+ if (Info.OrigArgs .size () != 5 ) {
1235+ LLVM_DEBUG (dbgs () << " No additional args allowed if flags == 0\n " );
1236+ return false ;
1237+ }
1238+ } else if (FlagsValue.isOneBitSet (0 )) {
1239+ IsDynamicVGPRChainCall = true ;
1240+
1241+ if (Info.OrigArgs .size () != 8 ) {
1242+ LLVM_DEBUG (dbgs () << " Expected 3 additional args" );
1243+ return false ;
1244+ }
1245+
1246+ // On GFX12, we can only change the VGPR allocation for wave32.
1247+ if (!ST.isWave32 ()) {
1248+ F.getContext ().diagnose (DiagnosticInfoUnsupported (
1249+ F, " Dynamic VGPR mode is only supported for wave32\n " ));
1250+ return false ;
1251+ }
1252+
1253+ ArgInfo FallbackExecArg = Info.OrigArgs [ChainCallArgIdx::FallbackExec];
1254+ assert (FallbackExecArg.Regs .size () == 1 &&
1255+ " Expected single register for fallback EXEC" );
1256+ if (!FallbackExecArg.Ty ->isIntegerTy (ST.getWavefrontSize ())) {
1257+ LLVM_DEBUG (dbgs () << " Bad type for fallback EXEC" );
1258+ return false ;
1259+ }
1260+ }
1261+ }
1262+
1263+ unsigned Opc = getCallOpcode (MF, Info.Callee .isReg (), /* IsTailCall*/ true ,
1264+ ST.isWave32 (), CalleeCC, IsDynamicVGPRChainCall);
12041265 auto MIB = MIRBuilder.buildInstrNoInsert (Opc);
1205- if (!addCallTargetOperands (MIB, MIRBuilder, Info))
1266+ if (!addCallTargetOperands (MIB, MIRBuilder, Info, IsDynamicVGPRChainCall ))
12061267 return false ;
12071268
12081269 // Byte offset for the tail call. When we are sibcalling, this will always
12091270 // be 0.
12101271 MIB.addImm (0 );
12111272
1212- // If this is a chain call, we need to pass in the EXEC mask.
1213- const SIRegisterInfo *TRI = ST.getRegisterInfo ();
1214- if (AMDGPU::isChainCC (Info.CallConv )) {
1215- ArgInfo ExecArg = Info.OrigArgs [1 ];
1273+ // If this is a chain call, we need to pass in the EXEC mask as well as any
1274+ // other special args.
1275+ if (IsChainCall) {
1276+ auto AddRegOrImm = [&](const ArgInfo &Arg) {
1277+ if (auto CI = dyn_cast<ConstantInt>(Arg.OrigValue )) {
1278+ MIB.addImm (CI->getSExtValue ());
1279+ } else {
1280+ MIB.addReg (Arg.Regs [0 ]);
1281+ unsigned Idx = MIB->getNumOperands () - 1 ;
1282+ MIB->getOperand (Idx).setReg (constrainOperandRegClass (
1283+ MF, *TRI, MRI, *TII, *ST.getRegBankInfo (), *MIB, MIB->getDesc (),
1284+ MIB->getOperand (Idx), Idx));
1285+ }
1286+ };
1287+
1288+ ArgInfo ExecArg = Info.OrigArgs [ChainCallArgIdx::Exec];
12161289 assert (ExecArg.Regs .size () == 1 && " Too many regs for EXEC" );
12171290
1218- if (!ExecArg.Ty ->isIntegerTy (ST.getWavefrontSize ()))
1291+ if (!ExecArg.Ty ->isIntegerTy (ST.getWavefrontSize ())) {
1292+ LLVM_DEBUG (dbgs () << " Bad type for EXEC" );
12191293 return false ;
1220-
1221- if (const auto *CI = dyn_cast<ConstantInt>(ExecArg.OrigValue )) {
1222- MIB.addImm (CI->getSExtValue ());
1223- } else {
1224- MIB.addReg (ExecArg.Regs [0 ]);
1225- unsigned Idx = MIB->getNumOperands () - 1 ;
1226- MIB->getOperand (Idx).setReg (constrainOperandRegClass (
1227- MF, *TRI, MRI, *ST.getInstrInfo (), *ST.getRegBankInfo (), *MIB,
1228- MIB->getDesc (), MIB->getOperand (Idx), Idx));
12291294 }
1295+
1296+ AddRegOrImm (ExecArg);
1297+ if (IsDynamicVGPRChainCall)
1298+ std::for_each (Info.OrigArgs .begin () + ChainCallArgIdx::NumVGPRs,
1299+ Info.OrigArgs .end (), AddRegOrImm);
12301300 }
12311301
12321302 // Tell the call which registers are clobbered.
@@ -1328,9 +1398,9 @@ bool AMDGPUCallLowering::lowerTailCall(
13281398 // FIXME: We should define regbankselectable call instructions to handle
13291399 // divergent call targets.
13301400 if (MIB->getOperand (0 ).isReg ()) {
1331- MIB->getOperand (0 ).setReg (constrainOperandRegClass (
1332- MF, *TRI, MRI, *ST. getInstrInfo () , *ST.getRegBankInfo (), *MIB ,
1333- MIB->getDesc (), MIB->getOperand (0 ), 0 ));
1401+ MIB->getOperand (0 ).setReg (
1402+ constrainOperandRegClass ( MF, *TRI, MRI, *TII , *ST.getRegBankInfo (),
1403+ *MIB, MIB->getDesc (), MIB->getOperand (0 ), 0 ));
13341404 }
13351405
13361406 MF.getFrameInfo ().setHasTailCall ();
@@ -1344,11 +1414,6 @@ bool AMDGPUCallLowering::lowerChainCall(MachineIRBuilder &MIRBuilder,
13441414 ArgInfo Callee = Info.OrigArgs [0 ];
13451415 ArgInfo SGPRArgs = Info.OrigArgs [2 ];
13461416 ArgInfo VGPRArgs = Info.OrigArgs [3 ];
1347- ArgInfo Flags = Info.OrigArgs [4 ];
1348-
1349- assert (cast<ConstantInt>(Flags.OrigValue )->isZero () &&
1350- " Non-zero flags aren't supported yet." );
1351- assert (Info.OrigArgs .size () == 5 && " Additional args aren't supported yet." );
13521417
13531418 MachineFunction &MF = MIRBuilder.getMF ();
13541419 const Function &F = MF.getFunction ();
0 commit comments