@@ -4017,29 +4017,26 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
40174017}
40184018
40194019// This is similar to the default implementation in ExpandDYNAMIC_STACKALLOC,
4020- // except for stack growth direction(default: downwards, AMDGPU: upwards) and
4021- // applying the wave size scale to the increment amount.
4020+ // except for:
4021+ // 1. stack growth direction(default: downwards, AMDGPU: upwards), and
4022+ // 2. scale size where, scale = wave-reduction(alloca-size) * wave-size
40224023SDValue SITargetLowering::lowerDYNAMIC_STACKALLOCImpl(SDValue Op,
40234024 SelectionDAG &DAG) const {
40244025 const MachineFunction &MF = DAG.getMachineFunction();
40254026 const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
40264027
40274028 SDLoc dl(Op);
40284029 EVT VT = Op.getValueType();
4029- SDValue Tmp1 = Op;
4030- SDValue Tmp2 = Op.getValue(1);
4031- SDValue Tmp3 = Op.getOperand(2);
4032- SDValue Chain = Tmp1.getOperand(0);
4033-
4030+ SDValue Chain = Op.getOperand(0);
40344031 Register SPReg = Info->getStackPtrOffsetReg();
40354032
40364033 // Chain the dynamic stack allocation so that it doesn't modify the stack
40374034 // pointer when other instructions are using the stack.
40384035 Chain = DAG.getCALLSEQ_START(Chain, 0, 0, dl);
40394036
4040- SDValue Size = Tmp2 .getOperand(1);
4037+ SDValue Size = Op.getValue(1) .getOperand(1);
40414038 SDValue BaseAddr = DAG.getCopyFromReg(Chain, dl, SPReg, VT);
4042- Align Alignment = cast<ConstantSDNode>(Tmp3 )->getAlignValue();
4039+ Align Alignment = cast<ConstantSDNode>(Op.getOperand(2) )->getAlignValue();
40434040
40444041 const TargetFrameLowering *TFL = Subtarget->getFrameLowering();
40454042 assert(TFL->getStackGrowthDirection() == TargetFrameLowering::StackGrowsUp &&
@@ -4057,30 +4054,41 @@ SDValue SITargetLowering::lowerDYNAMIC_STACKALLOCImpl(SDValue Op,
40574054 DAG.getSignedConstant(-ScaledAlignment, dl, VT));
40584055 }
40594056
4060- SDValue ScaledSize = DAG.getNode(
4061- ISD::SHL, dl, VT, Size,
4062- DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32));
4063-
4064- SDValue NewSP = DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value
4057+ assert(Size.getValueType() == MVT::i32 && "Size must be 32-bit");
4058+ SDValue NewSP;
4059+ if (isa<ConstantSDNode>(Op.getOperand(1))) {
4060+ // for constant sized alloca, scale alloca size by wave-size
4061+ SDValue ScaledSize = DAG.getNode(
4062+ ISD::SHL, dl, VT, Size,
4063+ DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32));
4064+ NewSP = DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value
4065+ } else {
4066+ // for dynamic sized alloca, perform wave-wide reduction to get max of
4067+ // alloca size(divergent) and then scale it by wave-size
4068+ SDValue WaveReduction =
4069+ DAG.getTargetConstant(Intrinsic::amdgcn_wave_reduce_umax, dl, MVT::i32);
4070+ Size = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::i32, WaveReduction,
4071+ Size, DAG.getConstant(0, dl, MVT::i32));
4072+ SDValue ScaledSize = DAG.getNode(
4073+ ISD::SHL, dl, VT, Size,
4074+ DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32));
4075+ NewSP =
4076+ DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value in vgpr.
4077+ SDValue ReadFirstLaneID =
4078+ DAG.getTargetConstant(Intrinsic::amdgcn_readfirstlane, dl, MVT::i32);
4079+ NewSP = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::i32, ReadFirstLaneID,
4080+ NewSP);
4081+ }
40654082
40664083 Chain = DAG.getCopyToReg(Chain, dl, SPReg, NewSP); // Output chain
4067- Tmp2 = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl);
4084+ SDValue CallSeqEnd = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl);
40684085
4069- return DAG.getMergeValues({BaseAddr, Tmp2 }, dl);
4086+ return DAG.getMergeValues({BaseAddr, CallSeqEnd }, dl);
40704087}
40714088
40724089SDValue SITargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
40734090 SelectionDAG &DAG) const {
4074- // We only handle constant sizes here to allow non-entry block, static sized
4075- // allocas. A truly dynamic value is more difficult to support because we
4076- // don't know if the size value is uniform or not. If the size isn't uniform,
4077- // we would need to do a wave reduction to get the maximum size to know how
4078- // much to increment the uniform stack pointer.
4079- SDValue Size = Op.getOperand(1);
4080- if (isa<ConstantSDNode>(Size))
4081- return lowerDYNAMIC_STACKALLOCImpl(Op, DAG); // Use "generic" expansion.
4082-
4083- return AMDGPUTargetLowering::LowerDYNAMIC_STACKALLOC(Op, DAG);
4091+ return lowerDYNAMIC_STACKALLOCImpl(Op, DAG); // Use "generic" expansion.
40844092}
40854093
40864094SDValue SITargetLowering::LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const {
0 commit comments