@@ -7280,24 +7280,25 @@ static SDValue lowerLaneOp(const SITargetLowering &TLI, SDNode *N,
72807280 return DAG.getBitcast(VT, UnrolledLaneOp);
72817281}
72827282
7283- static SDValue lowerSubgroupShuffle (const SITargetLowering &TLI, SDNode *N,
7283+ static SDValue lowerWaveShuffle (const SITargetLowering &TLI, SDNode *N,
72847284 SelectionDAG &DAG) {
72857285 EVT VT = N->getValueType(0);
72867286 unsigned ValSize = VT.getSizeInBits();
7287+ assert(ValSize == 32);
72877288 SDLoc SL(N);
72887289
72897290 SDValue Value = N->getOperand(1);
72907291 SDValue Index = N->getOperand(2);
72917292
72927293 // ds_bpermute requires index to be multiplied by 4
7293- SDValue ShiftAmount = DAG.getTargetConstant (2, SL, MVT::i32);
7294+ SDValue ShiftAmount = DAG.getShiftAmountConstant (2, MVT::i32, SL );
72947295 SDValue ShiftedIndex = DAG.getNode(ISD::SHL, SL, Index.getValueType(), Index,
72957296 ShiftAmount);
72967297
72977298 // Intrinsics will require i32 to operate on
7298- SDValue Value32 = Value;
7299- if ((ValSize != 32) || ( VT.isFloatingPoint() ))
7300- Value32 = DAG.getBitcast(MVT::i32, Value);
7299+ SDValue ValueI32 = Value;
7300+ if (VT.isFloatingPoint())
7301+ ValueI32 = DAG.getBitcast(MVT::i32, Value);
73017302
73027303 auto MakeIntrinsic = [&DAG, &SL](unsigned IID, MVT RetVT,
73037304 SmallVector<SDValue> IntrinArgs) -> SDValue {
@@ -7307,54 +7308,55 @@ static SDValue lowerSubgroupShuffle(const SITargetLowering &TLI, SDNode *N,
73077308 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, RetVT, Operands);
73087309 };
73097310
7311+ // If we can bpermute across the whole wave, then just do that
73107312 if (TLI.getSubtarget()->supportsWaveWideBPermute()) {
7311- // If we can bpermute across the whole wave, then just do that
73127313 SDValue BPermute = MakeIntrinsic(Intrinsic::amdgcn_ds_bpermute, MVT::i32,
7313- {ShiftedIndex, Value32 });
7314+ {ShiftedIndex, ValueI32 });
73147315 return DAG.getBitcast(VT, BPermute);
7315- } else {
7316- assert(TLI.getSubtarget()->isWave64());
7317-
7318- // Otherwise, we need to make use of whole wave mode
7319- SDValue PoisonVal = DAG.getPOISON(Value32->getValueType(0));
7320- SDValue PoisonIndex = DAG.getPOISON(ShiftedIndex->getValueType(0));
7321-
7322- // Set inactive lanes to poison
7323- SDValue WWMValue = MakeIntrinsic(Intrinsic::amdgcn_set_inactive, MVT::i32,
7324- {Value32, PoisonVal});
7325- SDValue WWMIndex = MakeIntrinsic(Intrinsic::amdgcn_set_inactive, MVT::i32,
7326- {ShiftedIndex, PoisonIndex});
7327-
7328- SDValue Swapped =
7329- MakeIntrinsic(Intrinsic::amdgcn_permlane64, MVT::i32, {WWMValue});
7330-
7331- // Get permutation of each half, then we'll select which one to use
7332- SDValue BPermSameHalf = MakeIntrinsic(Intrinsic::amdgcn_ds_bpermute,
7333- MVT::i32, {WWMIndex, WWMValue});
7334- SDValue BPermOtherHalf = MakeIntrinsic(Intrinsic::amdgcn_ds_bpermute,
7335- MVT::i32, {WWMIndex, Swapped});
7336- SDValue BPermOtherHalfWWM =
7337- MakeIntrinsic(Intrinsic::amdgcn_wwm, MVT::i32, {BPermOtherHalf});
7338-
7339- // Select which side to take the permute from
7340- SDValue ThreadIDMask = DAG.getTargetConstant(UINT32_MAX, SL, MVT::i32);
7341- SDValue ThreadIDLo =
7342- MakeIntrinsic(Intrinsic::amdgcn_mbcnt_lo, MVT::i32,
7343- {ThreadIDMask, DAG.getTargetConstant(0, SL, MVT::i32)});
7344- SDValue ThreadID = MakeIntrinsic(Intrinsic::amdgcn_mbcnt_hi, MVT::i32,
7345- {ThreadIDMask, ThreadIDLo});
7346-
7347- SDValue SameOrOtherHalf =
7348- DAG.getNode(ISD::AND, SL, MVT::i32,
7349- DAG.getNode(ISD::XOR, SL, MVT::i32, ThreadID, Index),
7350- DAG.getTargetConstant(32, SL, MVT::i32));
7351- SDValue UseSameHalf =
7352- DAG.getSetCC(SL, MVT::i1, SameOrOtherHalf,
7353- DAG.getConstant(0, SL, MVT::i32), ISD::SETEQ);
7354- SDValue Result = DAG.getSelect(SL, MVT::i32, UseSameHalf, BPermSameHalf,
7355- BPermOtherHalfWWM);
7356- return DAG.getBitcast(VT, Result);
73577316 }
7317+
7318+ assert(TLI.getSubtarget()->isWave64());
7319+
7320+ // Otherwise, we need to make use of whole wave mode
7321+ SDValue PoisonVal = DAG.getPOISON(ValueI32->getValueType(0));
7322+ SDValue PoisonIndex = DAG.getPOISON(ShiftedIndex->getValueType(0));
7323+
7324+ // Set inactive lanes to poison
7325+ SDValue WWMValue = MakeIntrinsic(Intrinsic::amdgcn_set_inactive, MVT::i32,
7326+ {ValueI32, PoisonVal});
7327+ SDValue WWMIndex = MakeIntrinsic(Intrinsic::amdgcn_set_inactive, MVT::i32,
7328+ {ShiftedIndex, PoisonIndex});
7329+
7330+ SDValue Swapped =
7331+ MakeIntrinsic(Intrinsic::amdgcn_permlane64, MVT::i32, {WWMValue});
7332+
7333+ // Get permutation of each half, then we'll select which one to use
7334+ SDValue BPermSameHalf = MakeIntrinsic(Intrinsic::amdgcn_ds_bpermute,
7335+ MVT::i32, {WWMIndex, WWMValue});
7336+ SDValue BPermOtherHalf = MakeIntrinsic(Intrinsic::amdgcn_ds_bpermute,
7337+ MVT::i32, {WWMIndex, Swapped});
7338+ SDValue BPermOtherHalfWWM =
7339+ MakeIntrinsic(Intrinsic::amdgcn_wwm, MVT::i32, {BPermOtherHalf});
7340+
7341+ // Select which side to take the permute from
7342+ SDValue ThreadIDMask = DAG.getAllOnesConstant(SL, MVT::i32);
7343+ // We can get away with only using mbcnt_lo here since we're only
7344+ // trying to detect which side of 32 each lane is on, and mbcnt_lo
7345+ // returns 32 for lanes 32-63.
7346+ SDValue ThreadID =
7347+ MakeIntrinsic(Intrinsic::amdgcn_mbcnt_lo, MVT::i32,
7348+ {ThreadIDMask, DAG.getTargetConstant(0, SL, MVT::i32)});
7349+
7350+ SDValue SameOrOtherHalf =
7351+ DAG.getNode(ISD::AND, SL, MVT::i32,
7352+ DAG.getNode(ISD::XOR, SL, MVT::i32, ThreadID, Index),
7353+ DAG.getTargetConstant(32, SL, MVT::i32));
7354+ SDValue UseSameHalf =
7355+ DAG.getSetCC(SL, MVT::i1, SameOrOtherHalf,
7356+ DAG.getConstant(0, SL, MVT::i32), ISD::SETEQ);
7357+ SDValue Result = DAG.getSelect(SL, MVT::i32, UseSameHalf, BPermSameHalf,
7358+ BPermOtherHalfWWM);
7359+ return DAG.getBitcast(VT, Result);
73587360}
73597361
73607362void SITargetLowering::ReplaceNodeResults(SDNode *N,
@@ -10264,8 +10266,8 @@ SDValue SITargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
1026410266 Poisons.push_back(DAG.getPOISON(ValTy));
1026510267 return DAG.getMergeValues(Poisons, SDLoc(Op));
1026610268 }
10267- case Intrinsic::amdgcn_subgroup_shuffle :
10268- return lowerSubgroupShuffle (*this, Op.getNode(), DAG);
10269+ case Intrinsic::amdgcn_wave_shuffle :
10270+ return lowerWaveShuffle (*this, Op.getNode(), DAG);
1026910271 default:
1027010272 if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
1027110273 AMDGPU::getImageDimIntrinsicInfo(IntrinsicID))
0 commit comments