@@ -4017,29 +4017,27 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
40174017}
40184018
40194019// This is similar to the default implementation in ExpandDYNAMIC_STACKALLOC,
4020- // except for stack growth direction(default: downwards, AMDGPU: upwards) and
4021- // applying the wave size scale to the increment amount.
4020+ // except:
4021+ // 1. stack growth direction (default: downwards, AMDGPU: upwards)
4022+ // 2. size scaling, where scaled size = wave-reduction(alloca-size) * wave-size
40224023SDValue SITargetLowering::lowerDYNAMIC_STACKALLOCImpl(SDValue Op,
40234024 SelectionDAG &DAG) const {
40244025 const MachineFunction &MF = DAG.getMachineFunction();
40254026 const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
40264027
40274028 SDLoc dl(Op);
40284029 EVT VT = Op.getValueType();
4029- SDValue Tmp1 = Op;
4030- SDValue Tmp2 = Op.getValue(1);
4031- SDValue Tmp3 = Op.getOperand(2);
4032- SDValue Chain = Tmp1.getOperand(0);
4033-
4030+ SDValue Tmp = Op.getValue(1);
4031+ SDValue Chain = Op.getOperand(0);
40344032 Register SPReg = Info->getStackPtrOffsetReg();
40354033
40364034 // Chain the dynamic stack allocation so that it doesn't modify the stack
40374035 // pointer when other instructions are using the stack.
40384036 Chain = DAG.getCALLSEQ_START(Chain, 0, 0, dl);
40394037
4040- SDValue Size = Tmp2 .getOperand(1);
4038+ SDValue Size = Tmp .getOperand(1);
40414039 SDValue BaseAddr = DAG.getCopyFromReg(Chain, dl, SPReg, VT);
4042- Align Alignment = cast<ConstantSDNode>(Tmp3 )->getAlignValue();
4040+ Align Alignment = cast<ConstantSDNode>(Op.getOperand(2) )->getAlignValue();
40434041
40444042 const TargetFrameLowering *TFL = Subtarget->getFrameLowering();
40454043 assert(TFL->getStackGrowthDirection() == TargetFrameLowering::StackGrowsUp &&
@@ -4057,30 +4055,39 @@ SDValue SITargetLowering::lowerDYNAMIC_STACKALLOCImpl(SDValue Op,
40574055 DAG.getSignedConstant(-ScaledAlignment, dl, VT));
40584056 }
40594057
4060- SDValue ScaledSize = DAG.getNode(
4061- ISD::SHL, dl, VT, Size,
4062- DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32));
4063-
4064- SDValue NewSP = DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value
4058+ assert(Size.getValueType() == MVT::i32 && "Size must be 32-bit");
4059+ SDValue NewSP;
4060+ if (isa<ConstantSDNode>(Op.getOperand(1))) {
4061+ SDValue ScaledSize = DAG.getNode(
4062+ ISD::SHL, dl, VT, Size,
4063+ DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32));
4064+ NewSP = DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value
4065+ } else {
4066+ // perform wave reduction to get the maximum size
4067+ SDValue WaveReduction =
4068+ DAG.getTargetConstant(Intrinsic::amdgcn_wave_reduce_umax, dl, MVT::i32);
4069+ Size = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::i32, WaveReduction,
4070+ Size, DAG.getConstant(0, dl, MVT::i32));
4071+ SDValue ScaledSize = DAG.getNode(
4072+ ISD::SHL, dl, VT, Size,
4073+ DAG.getConstant(Subtarget->getWavefrontSizeLog2(), dl, MVT::i32));
4074+ NewSP =
4075+ DAG.getNode(ISD::ADD, dl, VT, BaseAddr, ScaledSize); // Value in vgpr.
4076+ SDValue ReadFirstLaneID =
4077+ DAG.getTargetConstant(Intrinsic::amdgcn_readfirstlane, dl, MVT::i32);
4078+ NewSP = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::i32, ReadFirstLaneID,
4079+ NewSP);
4080+ }
40654081
40664082 Chain = DAG.getCopyToReg(Chain, dl, SPReg, NewSP); // Output chain
4067- Tmp2 = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl);
4083+ Tmp = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl);
40684084
4069- return DAG.getMergeValues({BaseAddr, Tmp2 }, dl);
4085+ return DAG.getMergeValues({BaseAddr, Tmp }, dl);
40704086}
40714087
40724088SDValue SITargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
40734089 SelectionDAG &DAG) const {
4074- // We only handle constant sizes here to allow non-entry block, static sized
4075- // allocas. A truly dynamic value is more difficult to support because we
4076- // don't know if the size value is uniform or not. If the size isn't uniform,
4077- // we would need to do a wave reduction to get the maximum size to know how
4078- // much to increment the uniform stack pointer.
4079- SDValue Size = Op.getOperand(1);
4080- if (isa<ConstantSDNode>(Size))
4081- return lowerDYNAMIC_STACKALLOCImpl(Op, DAG); // Use "generic" expansion.
4082-
4083- return AMDGPUTargetLowering::LowerDYNAMIC_STACKALLOC(Op, DAG);
4090+ return lowerDYNAMIC_STACKALLOCImpl(Op, DAG); // Use "generic" expansion.
40844091}
40854092
40864093SDValue SITargetLowering::LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const {
0 commit comments