@@ -1016,6 +1016,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
10161016 ISD::SCALAR_TO_VECTOR,
10171017 ISD::ZERO_EXTEND,
10181018 ISD::SIGN_EXTEND_INREG,
1019+ ISD::ANY_EXTEND,
10191020 ISD::EXTRACT_VECTOR_ELT,
10201021 ISD::INSERT_VECTOR_ELT,
10211022 ISD::FCOPYSIGN});
@@ -13429,6 +13430,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1342913430 return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
1343013431 }
1343113432
13433+ case ISD::ANY_EXTEND:
1343213434 case ISD::SIGN_EXTEND:
1343313435 case ISD::ZERO_EXTEND:
1343413436 case ISD::SIGN_EXTEND_INREG: {
@@ -14212,10 +14214,11 @@ SDValue SITargetLowering::performXorCombine(SDNode *N,
1421214214 return SDValue();
1421314215}
1421414216
14215- SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
14216- DAGCombinerInfo &DCI) const {
14217+ SDValue
14218+ SITargetLowering::performZeroOrAnyExtendCombine(SDNode *N,
14219+ DAGCombinerInfo &DCI) const {
1421714220 if (!Subtarget->has16BitInsts() ||
14218- DCI.getDAGCombineLevel() < AfterLegalizeDAG )
14221+ DCI.getDAGCombineLevel() < AfterLegalizeTypes )
1421914222 return SDValue();
1422014223
1422114224 EVT VT = N->getValueType(0);
@@ -14226,7 +14229,44 @@ SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
1422614229 if (Src.getValueType() != MVT::i16)
1422714230 return SDValue();
1422814231
14229- return SDValue();
14232+ if (!Src->hasOneUse())
14233+ return SDValue();
14234+
14235+ // TODO: We bail out below if SrcOffset is not in the first dword (>= 4). It's
14236+ // possible we're missing out on some combine opportunities, but we'd need to
14237+ // weigh the cost of extracting the byte from the upper dwords.
14238+
14239+ std::optional<ByteProvider<SDValue>> BP0 =
14240+ calculateByteProvider(SDValue(N, 0), 0, 0, 0);
14241+ if (!BP0 || BP0->SrcOffset >= 4 || !BP0->Src)
14242+ return SDValue();
14243+ SDValue V0 = *BP0->Src;
14244+
14245+ std::optional<ByteProvider<SDValue>> BP1 =
14246+ calculateByteProvider(SDValue(N, 0), 1, 0, 1);
14247+ if (!BP1 || BP1->SrcOffset >= 4 || !BP1->Src)
14248+ return SDValue();
14249+
14250+ SDValue V1 = *BP1->Src;
14251+
14252+ if (V0 == V1)
14253+ return SDValue();
14254+
14255+ SelectionDAG &DAG = DCI.DAG;
14256+ SDLoc DL(N);
14257+ uint32_t PermMask = 0x0c0c0c0c;
14258+ if (V0) {
14259+ V0 = DAG.getBitcastedAnyExtOrTrunc(V0, DL, MVT::i32);
14260+ PermMask = (PermMask & ~0xFF) | (BP0->SrcOffset + 4);
14261+ }
14262+
14263+ if (V1) {
14264+ V1 = DAG.getBitcastedAnyExtOrTrunc(V1, DL, MVT::i32);
14265+ PermMask = (PermMask & ~(0xFF << 8)) | (BP1->SrcOffset << 8);
14266+ }
14267+
14268+ return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, V0, V1,
14269+ DAG.getConstant(PermMask, DL, MVT::i32));
1423014270}
1423114271
1423214272SDValue
@@ -16861,8 +16901,9 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1686116901 }
1686216902 case ISD::XOR:
1686316903 return performXorCombine(N, DCI);
16904+ case ISD::ANY_EXTEND:
1686416905 case ISD::ZERO_EXTEND:
16865- return performZeroExtendCombine (N, DCI);
16906+ return performZeroOrAnyExtendCombine (N, DCI);
1686616907 case ISD::SIGN_EXTEND_INREG:
1686716908 return performSignExtendInRegCombine(N, DCI);
1686816909 case AMDGPUISD::FP_CLASS:
0 commit comments