@@ -53347,6 +53347,80 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
5334753347 return SDValue();
5334853348}
5334953349
53350+ // Look for a RMW operation that only touches one bit of a larger than legal
53351+ // type and fold it to a BTC/BTR/BTS pattern acting on a single i32 sub value.
53352+ static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,
53353+ SelectionDAG &DAG,
53354+ const X86Subtarget &Subtarget) {
53355+ using namespace SDPatternMatch;
53356+
53357+ // Only handle normal stores and its chain was a matching normal load.
53358+ auto *Ld = dyn_cast<LoadSDNode>(St->getChain());
53359+ if (!ISD::isNormalStore(St) || !St->isSimple() || !Ld ||
53360+ !ISD::isNormalLoad(Ld) || !Ld->isSimple() ||
53361+ Ld->getBasePtr() != St->getBasePtr() ||
53362+ Ld->getOffset() != St->getOffset())
53363+ return SDValue();
53364+
53365+ SDValue LoadVal(Ld, 0);
53366+ SDValue StoredVal = St->getValue();
53367+ EVT VT = StoredVal.getValueType();
53368+
53369+ // Only narrow larger than legal scalar integers.
53370+ if (!VT.isScalarInteger() ||
53371+ VT.getSizeInBits() <= (Subtarget.is64Bit() ? 64 : 32))
53372+ return SDValue();
53373+
53374+ // BTR: X & ~(1 << ShAmt)
53375+ // BTS: X | (1 << ShAmt)
53376+ // BTC: X ^ (1 << ShAmt)
53377+ SDValue ShAmt;
53378+ if (!StoredVal.hasOneUse() ||
53379+ !(sd_match(StoredVal, m_And(m_Specific(LoadVal),
53380+ m_Not(m_Shl(m_One(), m_Value(ShAmt))))) ||
53381+ sd_match(StoredVal,
53382+ m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) ||
53383+ sd_match(StoredVal,
53384+ m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt))))))
53385+ return SDValue();
53386+
53387+ // Ensure the shift amount is in bounds.
53388+ KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
53389+ if (KnownAmt.getMaxValue().uge(VT.getSizeInBits()))
53390+ return SDValue();
53391+
53392+ // Split the shift into an alignment shift that moves the active i32 block to
53393+ // the bottom bits for truncation and a modulo shift that can act on the i32.
53394+ EVT AmtVT = ShAmt.getValueType();
53395+ SDValue AlignAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
53396+ DAG.getSignedConstant(-32LL, DL, AmtVT));
53397+ SDValue ModuloAmt =
53398+ DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT));
53399+
53400+ // Compute the byte offset for the i32 block that is changed by the RMW.
53401+ // combineTruncate will adjust the load for us in a similar way.
53402+ EVT PtrVT = St->getBasePtr().getValueType();
53403+ SDValue PtrBitOfs = DAG.getZExtOrTrunc(AlignAmt, DL, PtrVT);
53404+ SDValue PtrByteOfs = DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs,
53405+ DAG.getShiftAmountConstant(3, PtrVT, DL));
53406+ SDValue NewPtr = DAG.getMemBasePlusOffset(St->getBasePtr(), PtrByteOfs, DL,
53407+ SDNodeFlags::NoUnsignedWrap);
53408+
53409+ // Reconstruct the BTC/BTR/BTS pattern for the i32 block and store.
53410+ SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt);
53411+ X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X);
53412+
53413+ SDValue Mask =
53414+ DAG.getNode(ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32),
53415+ DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8));
53416+ if (StoredVal.getOpcode() == ISD::AND)
53417+ Mask = DAG.getNOT(DL, Mask, MVT::i32);
53418+
53419+ SDValue Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask);
53420+ return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(),
53421+ Align(), St->getMemOperand()->getFlags());
53422+ }
53423+
5335053424static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
5335153425 TargetLowering::DAGCombinerInfo &DCI,
5335253426 const X86Subtarget &Subtarget) {
@@ -53573,6 +53647,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
5357353647 }
5357453648 }
5357553649
53650+ if (SDValue R = narrowBitOpRMW(St, dl, DAG, Subtarget))
53651+ return R;
53652+
5357653653 // Convert store(cmov(load(p), x, CC), p) to cstore(x, p, CC)
5357753654 // store(cmov(x, load(p), CC), p) to cstore(x, p, InvertCC)
5357853655 if ((VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) &&
@@ -54505,8 +54582,9 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
5450554582 // truncation, see if we can convert the shift into a pointer offset instead.
5450654583 // Limit this to normal (non-ext) scalar integer loads.
5450754584 if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL &&
54508- Src.hasOneUse() && Src.getOperand(0).hasOneUse() &&
54509- ISD::isNormalLoad(Src.getOperand(0).getNode())) {
54585+ Src.hasOneUse() && ISD::isNormalLoad(Src.getOperand(0).getNode()) &&
54586+ (Src.getOperand(0).hasOneUse() ||
54587+ !DAG.getTargetLoweringInfo().isOperationLegal(ISD::LOAD, SrcVT))) {
5451054588 auto *Ld = cast<LoadSDNode>(Src.getOperand(0));
5451154589 if (Ld->isSimple() && VT.isByteSized() &&
5451254590 isPowerOf2_64(VT.getSizeInBits())) {
@@ -56305,6 +56383,7 @@ static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
5630556383static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
5630656384 TargetLowering::DAGCombinerInfo &DCI,
5630756385 const X86Subtarget &Subtarget) {
56386+ using namespace SDPatternMatch;
5630856387 const ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();
5630956388 const SDValue LHS = N->getOperand(0);
5631056389 const SDValue RHS = N->getOperand(1);
@@ -56363,6 +56442,37 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
5636356442 if (SDValue AndN = MatchAndCmpEq(RHS, LHS))
5636456443 return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC);
5636556444
56445+ // If we're performing a bit test on a larger than legal type, attempt
56446+ // to (aligned) shift down the value to the bottom 32-bits and then
56447+ // perform the bittest on the i32 value.
56448+ // ICMP_ZERO(AND(X,SHL(1,IDX)))
56449+ // --> ICMP_ZERO(AND(TRUNC(SRL(X,AND(IDX,-32))),SHL(1,AND(IDX,31))))
56450+ if (isNullConstant(RHS) &&
56451+ OpVT.getScalarSizeInBits() > (Subtarget.is64Bit() ? 64 : 32)) {
56452+ SDValue X, ShAmt;
56453+ if (sd_match(LHS, m_OneUse(m_And(m_Value(X),
56454+ m_Shl(m_One(), m_Value(ShAmt)))))) {
56455+ // Only attempt this if the shift amount is known to be in bounds.
56456+ KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
56457+ if (KnownAmt.getMaxValue().ult(OpVT.getScalarSizeInBits())) {
56458+ EVT AmtVT = ShAmt.getValueType();
56459+ SDValue AlignAmt =
56460+ DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
56461+ DAG.getSignedConstant(-32LL, DL, AmtVT));
56462+ SDValue ModuloAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
56463+ DAG.getConstant(31, DL, AmtVT));
56464+ SDValue Mask = DAG.getNode(
56465+ ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32),
56466+ DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8));
56467+ X = DAG.getNode(ISD::SRL, DL, OpVT, X, AlignAmt);
56468+ X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X);
56469+ X = DAG.getNode(ISD::AND, DL, MVT::i32, X, Mask);
56470+ return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, MVT::i32),
56471+ CC);
56472+ }
56473+ }
56474+ }
56475+
5636656476 // cmpeq(trunc(x),C) --> cmpeq(x,C)
5636756477 // cmpne(trunc(x),C) --> cmpne(x,C)
5636856478 // iff x upper bits are zero.
0 commit comments