@@ -3593,6 +3593,8 @@ bool SITargetLowering::isEligibleForTailCallOptimization(
35933593 SmallVector<CCValAssign, 16> ArgLocs;
35943594 CCState CCInfo(CalleeCC, IsVarArg, MF, ArgLocs, Ctx);
35953595
3596+ // FIXME: We are not allocating special input registers, so we will be
3597+ // deciding based on incorrect register assignments.
35963598 CCInfo.AnalyzeCallOperands(Outs, CCAssignFnForCall(CalleeCC, IsVarArg));
35973599
35983600 const SIMachineFunctionInfo *FuncInfo = MF.getInfo<SIMachineFunctionInfo>();
@@ -3602,6 +3604,21 @@ bool SITargetLowering::isEligibleForTailCallOptimization(
36023604 if (CCInfo.getStackSize() > FuncInfo->getBytesInStackArgArea())
36033605 return false;
36043606
3607+ for (const auto &[CCVA, ArgVal] : zip_equal(ArgLocs, OutVals)) {
3608+ // FIXME: What about inreg arguments that end up passed in memory?
3609+ if (!CCVA.isRegLoc())
3610+ continue;
3611+
3612+ // If we are passing an argument in an SGPR, and the value is divergent,
3613+ // this call requires a waterfall loop.
3614+ if (ArgVal->isDivergent() && TRI->isSGPRPhysReg(CCVA.getLocReg())) {
3615+ LLVM_DEBUG(
3616+ dbgs() << "Cannot tail call due to divergent outgoing argument in "
3617+ << printReg(CCVA.getLocReg(), TRI) << '\n');
3618+ return false;
3619+ }
3620+ }
3621+
36053622 const MachineRegisterInfo &MRI = MF.getRegInfo();
36063623 return parametersInCSRMatch(MRI, CallerPreserved, ArgLocs, OutVals);
36073624}
@@ -3734,6 +3751,7 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
37343751 // arguments to begin at SP+0. Completely unused for non-tail calls.
37353752 int32_t FPDiff = 0;
37363753 MachineFrameInfo &MFI = MF.getFrameInfo();
3754+ auto *TRI = static_cast<const SIRegisterInfo *>(Subtarget->getRegisterInfo());
37373755
37383756 // Adjust the stack pointer for the new arguments...
37393757 // These operations are automatically eliminated by the prolog/epilog pass
@@ -3756,6 +3774,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
37563774 }
37573775 }
37583776
3777+ const unsigned NumSpecialInputs = RegsToPass.size();
3778+
37593779 MVT PtrVT = MVT::i32;
37603780
37613781 // Walk the register/memloc assignments, inserting copies/loads.
@@ -3857,16 +3877,40 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
38573877 if (!MemOpChains.empty())
38583878 Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
38593879
3880+ SDValue ReadFirstLaneID =
3881+ DAG.getTargetConstant(Intrinsic::amdgcn_readfirstlane, DL, MVT::i32);
3882+
3883+ SDValue TokenGlue;
3884+ if (CLI.ConvergenceControlToken) {
3885+ TokenGlue = DAG.getNode(ISD::CONVERGENCECTRL_GLUE, DL, MVT::Glue,
3886+ CLI.ConvergenceControlToken);
3887+ }
3888+
38603889 // Build a sequence of copy-to-reg nodes chained together with token chain
38613890 // and flag operands which copy the outgoing args into the appropriate regs.
38623891 SDValue InGlue;
3863- for (auto &RegToPass : RegsToPass) {
3864- Chain = DAG.getCopyToReg(Chain, DL, RegToPass.first,
3865- RegToPass.second, InGlue);
3892+
3893+ unsigned ArgIdx = 0;
3894+ for (auto [Reg, Val] : RegsToPass) {
3895+ if (ArgIdx++ >= NumSpecialInputs && !Val->isDivergent() &&
3896+ TRI->isSGPRPhysReg(Reg)) {
3897+ // Speculatively insert a readfirstlane in case this is a uniform value in
3898+ // a VGPR.
3899+ //
3900+ // FIXME: We need to execute this in a waterfall loop if it is a divergent
3901+ // value, so let that continue to produce invalid code.
3902+
3903+ SmallVector<SDValue, 3> ReadfirstlaneArgs({ReadFirstLaneID, Val});
3904+ if (TokenGlue)
3905+ ReadfirstlaneArgs.push_back(TokenGlue);
3906+ Val = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Val.getValueType(),
3907+ ReadfirstlaneArgs);
3908+ }
3909+
3910+ Chain = DAG.getCopyToReg(Chain, DL, Reg, Val, InGlue);
38663911 InGlue = Chain.getValue(1);
38673912 }
38683913
3869-
38703914 // We don't usually want to end the call-sequence here because we would tidy
38713915 // the frame up *after* the call, however in the ABI-changing tail-call case
38723916 // we've carefully laid out the parameters so that when sp is reset they'll be
@@ -3896,12 +3940,8 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
38963940 DAG.getTargetConstant(Intrinsic::amdgcn_readfirstlane, DL, MVT::i32);
38973941
38983942 SmallVector<SDValue, 3> ReadfirstlaneArgs({ReadFirstLaneID, Callee});
3899- if (CLI.ConvergenceControlToken) {
3900- SDValue TokenGlue = DAG.getNode(ISD::CONVERGENCECTRL_GLUE, {},
3901- MVT::Glue, CLI.ConvergenceControlToken);
3943+ if (TokenGlue)
39023944 ReadfirstlaneArgs.push_back(TokenGlue); // Wire up convergence token.
3903- }
3904-
39053945 Callee = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Callee.getValueType(),
39063946 ReadfirstlaneArgs);
39073947 }
@@ -3928,7 +3968,6 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
39283968 }
39293969
39303970 // Add a register mask operand representing the call-preserved registers.
3931- auto *TRI = static_cast<const SIRegisterInfo *>(Subtarget->getRegisterInfo());
39323971 const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
39333972 assert(Mask && "Missing call preserved mask for calling convention");
39343973 Ops.push_back(DAG.getRegisterMask(Mask));
0 commit comments