diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 42ddb32d24093..1439683dc5e96 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -28,6 +28,7 @@
 #include "llvm/Analysis/TargetTransformInfoImpl.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/TargetSubtargetInfo.h"
 #include "llvm/CodeGen/ValueTypes.h"
@@ -1244,9 +1245,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       EVT LoadVT = EVT::getEVT(Src);
       unsigned LType =
           ((Opcode == Instruction::ZExt) ? ISD::ZEXTLOAD : ISD::SEXTLOAD);
-      if (DstLT.first == SrcLT.first &&
-          TLI->isLoadExtLegal(LType, ExtVT, LoadVT))
-        return 0;
+
+      if (I && isa<LoadInst>(I->getOperand(0))) {
+        auto *LI = cast<LoadInst>(I->getOperand(0));
+
+        if (DstLT.first == SrcLT.first &&
+            TLI->isLoadExtLegal(LType, ExtVT, LoadVT,
+                                LI->getPointerAddressSpace()))
+          return 0;
+      }
     }
     break;
   case Instruction::AddrSpaceCast:
@@ -1531,7 +1538,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     if (Opcode == Instruction::Store)
       LA = getTLI()->getTruncStoreAction(LT.second, MemVT);
     else
-      LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT);
+      LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT,
+                                      AddressSpace);
 
     if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) {
       // This is a vector load/store for some illegal type that is scalarized.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 88691b931a8d8..0b160443ade59 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1472,27 +1472,38 @@ class LLVM_ABI TargetLoweringBase {
   /// Return how this load with extension should be treated: either it is legal,
   /// needs to be promoted to a larger size, needs to be expanded to some other
   /// code sequence, or the target has a custom expander for it.
-  LegalizeAction getLoadExtAction(unsigned ExtType, EVT ValVT,
-                                  EVT MemVT) const {
-    if (ValVT.isExtended() || MemVT.isExtended()) return Expand;
-    unsigned ValI = (unsigned) ValVT.getSimpleVT().SimpleTy;
-    unsigned MemI = (unsigned) MemVT.getSimpleVT().SimpleTy;
+  LegalizeAction getLoadExtAction(unsigned ExtType, EVT ValVT, EVT MemVT,
+                                  unsigned AddrSpace) const {
+    if (ValVT.isExtended() || MemVT.isExtended())
+      return Expand;
+    unsigned ValI = (unsigned)ValVT.getSimpleVT().SimpleTy;
+    unsigned MemI = (unsigned)MemVT.getSimpleVT().SimpleTy;
     assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValI < MVT::VALUETYPE_SIZE &&
            MemI < MVT::VALUETYPE_SIZE && "Table isn't big enough!");
     unsigned Shift = 4 * ExtType;
+
+    uint64_t OverrideKey = ((uint64_t)(ValI & 0xFF) << 40) |
+                           ((uint64_t)(MemI & 0xFF) << 32) |
+                           (uint64_t)AddrSpace;
+
+    if (LoadExtActionOverrides.count(OverrideKey)) {
+      return (LegalizeAction)(
+          (LoadExtActionOverrides.at(OverrideKey) >> Shift) & 0xf);
+    }
     return (LegalizeAction)((LoadExtActions[ValI][MemI] >> Shift) & 0xf);
   }
 
   /// Return true if the specified load with extension is legal on this target.
-  bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT) const {
-    return getLoadExtAction(ExtType, ValVT, MemVT) == Legal;
+  bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT,
+                      unsigned AddrSpace) const {
+    return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal;
   }
 
   /// Return true if the specified load with extension is legal or custom
   /// on this target.
-  bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT) const {
-    return getLoadExtAction(ExtType, ValVT, MemVT) == Legal ||
-           getLoadExtAction(ExtType, ValVT, MemVT) == Custom;
+  bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT,
+                              unsigned AddrSpace) const {
+    return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal ||
+           getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Custom;
   }
 
   /// Same as getLoadExtAction, but for atomic loads.
@@ -2634,23 +2645,38 @@ class LLVM_ABI TargetLoweringBase {
   /// Indicate that the specified load with extension does not work with the
   /// specified type and indicate what to do about it.
   void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT,
-                        LegalizeAction Action) {
+                        LegalizeAction Action, unsigned AddrSpace = ~0) {
     assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValVT.isValid() &&
            MemVT.isValid() && "Table isn't big enough!");
     assert((unsigned)Action < 0x10 && "too many bits for bitfield array");
+
     unsigned Shift = 4 * ExtType;
-    LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] &= ~((uint16_t)0xF << Shift);
-    LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] |= (uint16_t)Action << Shift;
+
+    if (AddrSpace == ~((unsigned)0)) {
+      LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] &=
+          ~((uint16_t)0xF << Shift);
+      LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] |= (uint16_t)Action
+                                                        << Shift;
+    } else {
+      uint64_t OverrideKey = ((uint64_t)(ValVT.SimpleTy & 0xFF) << 40) |
+                             ((uint64_t)(MemVT.SimpleTy & 0xFF) << 32) |
+                             (uint64_t)AddrSpace;
+      uint16_t &OverrideVal = LoadExtActionOverrides[OverrideKey];
+
+      OverrideVal &= ~((uint16_t)0xF << Shift);
+      OverrideVal |= (uint16_t)Action << Shift;
+    }
   }
 
   void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT, MVT MemVT,
-                        LegalizeAction Action) {
+                        LegalizeAction Action, unsigned AddrSpace = ~0) {
     for (auto ExtType : ExtTypes)
-      setLoadExtAction(ExtType, ValVT, MemVT, Action);
+      setLoadExtAction(ExtType, ValVT, MemVT, Action, AddrSpace);
   }
 
   void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT,
-                        ArrayRef<MVT> MemVTs, LegalizeAction Action) {
+                        ArrayRef<MVT> MemVTs, LegalizeAction Action,
+                        unsigned AddrSpace = ~0) {
     for (auto MemVT : MemVTs)
-      setLoadExtAction(ExtTypes, ValVT, MemVT, Action);
+      setLoadExtAction(ExtTypes, ValVT, MemVT, Action, AddrSpace);
   }
 
   /// Let target indicate that an extending atomic load of the specified type
@@ -3126,7 +3152,7 @@ class LLVM_ABI TargetLoweringBase {
       LType = ISD::SEXTLOAD;
     }
 
-    return isLoadExtLegal(LType, VT, LoadVT);
+    return isLoadExtLegal(LType, VT, LoadVT, Load->getPointerAddressSpace());
   }
 
   /// Return true if any actual instruction that defines a value of type FromTy
@@ -3753,8 +3779,11 @@ class LLVM_ABI TargetLoweringBase {
   /// For each load extension type and each value type, keep a LegalizeAction
   /// that indicates how instruction selection should deal with a load of a
   /// specific value type and extension type. Uses 4-bits to store the action
-  /// for each of the 4 load ext types.
+  /// for each of the 4 load ext types. These actions can be specified for each
+  /// address space.
   uint16_t LoadExtActions[MVT::VALUETYPE_SIZE][MVT::VALUETYPE_SIZE];
+  using LoadExtActionOverrideMap = std::map<uint64_t, uint16_t>;
+  LoadExtActionOverrideMap LoadExtActionOverrides;
 
   /// Similar to LoadExtActions, but for atomic loads. Only Legal or Expand
   /// (default) values are supported.
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index eb73d01b3558c..0b7ddf1211f54 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -7347,7 +7347,8 @@ bool CodeGenPrepare::optimizeLoadExt(LoadInst *Load) {
 
   // Reject cases that won't be matched as extloads.
   if (!LoadResultVT.bitsGT(TruncVT) || !TruncVT.isRound() ||
-      !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT))
+      !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT,
+                           Load->getPointerAddressSpace()))
     return false;
 
   IRBuilder<> Builder(Load->getNextNode());
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 309f1bea8b77c..8a8de4ba97f92 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6889,7 +6889,8 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
 
   if (ExtVT == LoadedVT &&
       (!LegalOperations ||
-       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
+       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT,
+                          LoadN->getAddressSpace()))) {
     // ZEXTLOAD will match without needing to change the size of the value being
     // loaded.
     return true;
@@ -6904,8 +6905,8 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
     return false;
 
-  if (LegalOperations &&
-      !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
+  if (LegalOperations && !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT,
+                                             LoadN->getAddressSpace()))
     return false;
 
   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT, /*ByteOffset=*/0))
@@ -6967,8 +6968,8 @@ bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
   if (!SDValue(Load, 0).hasOneUse())
     return false;
 
-  if (LegalOperations &&
-      !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
+  if (LegalOperations && !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT,
+                                             Load->getAddressSpace()))
     return false;
 
   // For the transform to be legal, the load must produce only two values
@@ -7480,7 +7481,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
   if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
     EVT LoadVT = MLoad->getMemoryVT();
     EVT ExtVT = VT;
-    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
+    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT,
+                           MLoad->getAddressSpace())) {
       // For this AND to be a zero extension of the masked load the elements
       // of the BuildVec must mask the bottom bits of the extended element
      // type
@@ -7631,9 +7633,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
     // actually legal and isn't going to get expanded, else this is a false
     // optimisation.
-    bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
-                                                    Load->getValueType(0),
-                                                    Load->getMemoryVT());
+    bool CanZextLoadProfitably =
+        TLI.isLoadExtLegal(ISD::ZEXTLOAD, Load->getValueType(0),
+                           Load->getMemoryVT(), Load->getAddressSpace());
 
     // Resize the constant to the same size as the original memory access before
     // extension. If it is still the AllOnesValue then this AND is completely
     // unnecessary.
@@ -7825,7 +7827,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
         ((!LegalOperations && LN0->isSimple()) ||
-         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
+         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, LN0->getAddressSpace()))) {
       SDValue ExtLoad =
           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
@@ -9747,10 +9749,13 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   // Before legalize we can introduce too wide illegal loads which will be later
   // split into legal sized loads. This enables us to combine i64 load by i8
   // patterns to a couple of i32 loads on 32 bit targets.
-  if (LegalOperations &&
-      !TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
-                          MemVT))
-    return SDValue();
+  if (LegalOperations) {
+    for (auto *L : Loads) {
+      if (!TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
+                              MemVT, L->getAddressSpace()))
+        return SDValue();
+    }
+  }
 
   // Check if the bytes of the OR we are looking at match with either big or
   // little endian value load
@@ -13425,9 +13430,11 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
       unsigned WideWidth = WideVT.getScalarSizeInBits();
       bool IsSigned = isSignedIntSetCC(CC);
       auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
-      if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
-          SetCCWidth != 1 && SetCCWidth < WideWidth &&
-          TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
+      if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() && SetCCWidth != 1 &&
+          SetCCWidth < WideWidth &&
+          TLI.isLoadExtLegalOrCustom(
+              LoadExtOpcode, WideVT, NarrowVT,
+              cast<LoadSDNode>(LHS)->getAddressSpace()) &&
           TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
         // Both compare operands can be widened for free. The LHS can use an
         // extended load, and the RHS is a constant:
@@ -13874,8 +13881,10 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
   // Combine2), so we should conservatively check the OperationAction.
   LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
   LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
-  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
-      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) ||
+  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT(),
+                          Load1->getAddressSpace()) ||
+      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT(),
+                          Load2->getAddressSpace()) ||
       (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
        TLI.getOperationAction(ISD::VSELECT, VT) != TargetLowering::Legal))
     return SDValue();
@@ -14099,13 +14108,15 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
   // Try to split the vector types to get down to legal types.
   EVT SplitSrcVT = SrcVT;
   EVT SplitDstVT = DstVT;
-  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
+  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT,
+                                     LN0->getAddressSpace()) &&
          SplitSrcVT.getVectorNumElements() > 1) {
     SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
     SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
   }
 
-  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
+  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT,
+                                  LN0->getAddressSpace()))
     return SDValue();
 
   assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
@@ -14178,7 +14189,7 @@ SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
     return SDValue();
   LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
   EVT MemVT = Load->getMemoryVT();
-  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
+  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, Load->getAddressSpace()) ||
       Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
     return SDValue();
@@ -14286,9 +14297,8 @@ static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   EVT MemVT = LN0->getMemoryVT();
-  if ((LegalOperations || !LN0->isSimple() ||
-       VT.isVector()) &&
-      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
+  if ((LegalOperations || !LN0->isSimple() || VT.isVector()) &&
+      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT, LN0->getAddressSpace()))
     return SDValue();
 
   SDValue ExtLoad =
@@ -14330,12 +14340,13 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
     }
   }
 
+  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   // TODO: isFixedLengthVector() should be removed and any negative effects on
   // code generation being the result of that target's implementation of
   // isVectorLoadExtDesirable().
-  if ((LegalOperations || VT.isFixedLengthVector() ||
-       !cast<LoadSDNode>(N0)->isSimple()) &&
-      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))
+  if ((LegalOperations || VT.isFixedLengthVector() || !LN0->isSimple()) &&
+      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(),
+                          LN0->getAddressSpace()))
     return {};
 
   bool DoXform = true;
@@ -14347,7 +14358,6 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
   if (!DoXform)
     return {};
 
-  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   SDValue ExtLoad =
       DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                      LN0->getBasePtr(), N0.getValueType(), LN0->getMemOperand());
@@ -14377,8 +14387,9 @@ tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
   if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
     return SDValue();
 
-  if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) &&
-      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
+  if ((LegalOperations || !Ld->isSimple()) &&
+      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0),
+                                  Ld->getAddressSpace()))
     return SDValue();
 
   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
@@ -14522,7 +14533,8 @@ SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
       if (!(ISD::isNON_EXTLoad(V.getNode()) &&
             ISD::isUNINDEXEDLoad(V.getNode()) &&
            cast<LoadSDNode>(V)->isSimple() &&
-           TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
+           TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType(),
+                              cast<LoadSDNode>(V)->getAddressSpace())))
        return false;
 
      // Non-chain users of this value must either be the setcc in this
@@ -14719,8 +14731,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
         (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
       LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
       EVT MemVT = LN00->getMemoryVT();
-      if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
-          LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
+      if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT, LN00->getAddressSpace()) &&
+          LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
         SmallVector<SDNode *, 4> SetCCs;
         bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                                ISD::SIGN_EXTEND, SetCCs, TLI);
@@ -15037,7 +15049,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
         (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
       LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
       EVT MemVT = LN00->getMemoryVT();
-      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
+      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, LN00->getAddressSpace()) &&
           LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
         bool DoXform = true;
         SmallVector<SDNode *, 4> SetCCs;
@@ -15268,7 +15280,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
       return foldedExt;
   } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
             ISD::isUNINDEXEDLoad(N0.getNode()) &&
-            TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
+            TLI.isLoadExtLegalOrCustom(
+                ISD::EXTLOAD, VT, N0.getValueType(),
+                cast<LoadSDNode>(N0)->getAddressSpace())) {
     bool DoXform = true;
     SmallVector<SDNode *, 4> SetCCs;
     if (!N0.hasOneUse())
@@ -15303,7 +15317,8 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
     ISD::LoadExtType ExtType = LN0->getExtensionType();
     EVT MemVT = LN0->getMemoryVT();
-    if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) {
+    if (!LegalOperations ||
+        TLI.isLoadExtLegal(ExtType, VT, MemVT, LN0->getAddressSpace())) {
       SDValue ExtLoad =
           DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), LN0->getBasePtr(),
                          MemVT, LN0->getMemOperand());
@@ -15617,7 +15632,8 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
           EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countr_one());
       // If the mask is smaller, recompute the type.
       if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
-          TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
+          TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT,
+                             LN->getAddressSpace()))
         ExtVT = MaskedVT;
     } else if (ExtType == ISD::ZEXTLOAD &&
                ShiftMask.isShiftedMask(Offset, ActiveBits) &&
@@ -15626,7 +15642,8 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
       // If the mask is shifted we can use a narrower load and a shl to insert
      // the trailing zeros.
      if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) &&
-         TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT)) {
+         TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT,
+                            LN->getAddressSpace())) {
        ExtVT = MaskedVT;
        ShAmt = Offset + ShAmt;
        ShiftedOffset = Offset;
@@ -15852,7 +15869,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
       ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() &&
         N0.hasOneUse()) ||
-       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
+       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT,
+                          cast<LoadSDNode>(N0)->getAddressSpace()))) {
     auto *LN0 = cast<LoadSDNode>(N0);
     SDValue ExtLoad =
         DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
@@ -15867,7 +15885,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
       N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
       ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) &&
-       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) {
+       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT,
+                          cast<LoadSDNode>(N0)->getAddressSpace()))) {
     auto *LN0 = cast<LoadSDNode>(N0);
     SDValue ExtLoad =
         DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
@@ -15882,7 +15901,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
   if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) {
     if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() &&
         Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD &&
-        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) {
+        TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT, Ld->getAddressSpace())) {
       SDValue ExtMaskedLoad = DAG.getMaskedLoad(
           VT, DL, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(),
           Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(),
@@ -19222,7 +19241,8 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
 
   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
   if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
-      TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
+      TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType(),
+                                 cast<LoadSDNode>(N0)->getAddressSpace())) {
     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
     SDValue ExtLoad =
         DAG.getExtLoad(ISD::EXTLOAD, DL, VT, LN0->getChain(),
@@ -22274,12 +22294,16 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes,
       } else if (TLI.getTypeAction(Context, StoreTy) ==
                  TargetLowering::TypePromoteInteger) {
         EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy);
+        unsigned AS = LoadNodes[i].MemNode->getAddressSpace();
         if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) &&
             TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy,
                                  DAG.getMachineFunction()) &&
-            TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) &&
-            TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) &&
-            TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) &&
+            TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy,
+                               AS) &&
+            TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy,
+                               AS) &&
+            TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy,
+                               AS) &&
             TLI.allowsMemoryAccess(Context, DL, StoreTy,
                                    *FirstInChain->getMemOperand(), &IsFastSt) &&
             IsFastSt &&
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 5fb7e63cfb605..532a8c490b481 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -348,7 +348,7 @@ SelectionDAGLegalize::ExpandConstantFP(ConstantFPSDNode *CFP, bool UseCP) {
     if (ConstantFPSDNode::isValueValidForType(SVT, APF) &&
         // Only do this if the target has a native EXTLOAD instruction from
        // smaller type.
-        TLI.isLoadExtLegal(ISD::EXTLOAD, OrigVT, SVT) &&
+        TLI.isLoadExtLegal(ISD::EXTLOAD, OrigVT, SVT, 0) &&
        TLI.ShouldShrinkFPConstant(OrigVT)) {
      Type *SType = SVT.getTypeForEVT(*DAG.getContext());
      LLVMC = cast<ConstantFP>(ConstantFoldCastOperand(
@@ -740,8 +740,9 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
         // nice to have an effective generic way of getting these benefits...
         // Until such a way is found, don't insist on promoting i1 here.
         (SrcVT != MVT::i1 ||
-         TLI.getLoadExtAction(ExtType, Node->getValueType(0), MVT::i1) ==
-             TargetLowering::Promote)) {
+         TLI.getLoadExtAction(ExtType, Node->getValueType(0), MVT::i1,
+                              LD->getAddressSpace()) ==
+             TargetLowering::Promote)) {
       // Promote to a byte-sized load if not loading an integral number of
       // bytes.  For example, promote EXTLOAD:i20 -> EXTLOAD:i24.
       unsigned NewWidth = SrcVT.getStoreSizeInBits();
@@ -852,7 +853,7 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
     } else {
       bool isCustom = false;
       switch (TLI.getLoadExtAction(ExtType, Node->getValueType(0),
-                                   SrcVT.getSimpleVT())) {
+                                   SrcVT.getSimpleVT(), LD->getAddressSpace())) {
       default: llvm_unreachable("This action is not supported yet!");
       case TargetLowering::Custom:
        isCustom = true;
@@ -880,13 +881,15 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
       case TargetLowering::Expand: {
         EVT DestVT = Node->getValueType(0);
-        if (!TLI.isLoadExtLegal(ISD::EXTLOAD, DestVT, SrcVT)) {
+        if (!TLI.isLoadExtLegal(ISD::EXTLOAD, DestVT, SrcVT,
+                                LD->getAddressSpace())) {
           // If the source type is not legal, see if there is a legal extload to
           // an intermediate type that we can then extend further.
           EVT LoadVT = TLI.getRegisterType(SrcVT.getSimpleVT());
           if ((LoadVT.isFloatingPoint() == SrcVT.isFloatingPoint()) &&
               (TLI.isTypeLegal(SrcVT) || // Same as SrcVT == LoadVT?
-               TLI.isLoadExtLegal(ExtType, LoadVT, SrcVT))) {
+               TLI.isLoadExtLegal(ExtType, LoadVT, SrcVT,
+                                  LD->getAddressSpace()))) {
             // If we are loading a legal type, this is a non-extload followed by a
             // full extend.
             ISD::LoadExtType MidExtType =
@@ -1846,7 +1849,8 @@ SDValue SelectionDAGLegalize::EmitStackConvert(SDValue SrcOp, EVT SlotVT,
   if ((SrcVT.bitsGT(SlotVT) &&
        !TLI.isTruncStoreLegalOrCustom(SrcOp.getValueType(), SlotVT)) ||
       (SlotVT.bitsLT(DestVT) &&
-       !TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT)))
+       !TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT,
+                                   DAG.getDataLayout().getAllocaAddrSpace())))
     return SDValue();
 
   // Create the stack frame object.
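
Reviewer note: the override lookup introduced in `getLoadExtAction`/`setLoadExtAction` above keys a `std::map<uint64_t, uint16_t>` on a packed (ValVT, MemVT, AddrSpace) triple, with four 4-bit actions per payload. A self-contained sketch of that scheme, outside the patch and with illustrative names only:

```cpp
#include <cassert>
#include <cstdint>
#include <map>

// Stand-in for TargetLowering's LegalizeAction (Legal == 0 in LLVM).
enum LegalizeAction : uint8_t { Legal, Promote, Expand, LibCall, Custom };

// Key packing as in the patch: value type in bits [47:40], memory type in
// bits [39:32], address space in bits [31:0]. The 0xFF masks assume simple
// value-type enumerators fit in 8 bits.
static uint64_t packKey(unsigned ValI, unsigned MemI, unsigned AddrSpace) {
  return ((uint64_t)(ValI & 0xFF) << 40) | ((uint64_t)(MemI & 0xFF) << 32) |
         (uint64_t)AddrSpace;
}

int main() {
  std::map<uint64_t, uint16_t> Overrides;

  // Record Expand for ext type 1 of the triple (Val=8, Mem=4, AS=3), the way
  // setLoadExtAction does it: each uint16_t payload holds four 4-bit actions.
  unsigned Shift = 4 * /*ExtType=*/1;
  uint16_t &Payload = Overrides[packKey(8, 4, 3)];
  Payload &= ~((uint16_t)0xF << Shift);
  Payload |= (uint16_t)Expand << Shift;

  // Lookup mirrors getLoadExtAction: an existing override wins outright.
  assert(((Overrides.at(packKey(8, 4, 3)) >> Shift) & 0xf) == Expand);
  // Side effect worth noting: the other ext types of that triple now read as
  // 0 (Legal) from the override, shadowing the generic LoadExtActions table.
  assert(((Overrides.at(packKey(8, 4, 3)) >> 0) & 0xf) == Legal);
  return 0;
}
```
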
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4f83b38..be8e780a6f55d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -301,7 +301,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
     ISD::LoadExtType ExtType = LD->getExtensionType();
     EVT LoadedVT = LD->getMemoryVT();
     if (LoadedVT.isVector() && ExtType != ISD::NON_EXTLOAD)
-      Action = TLI.getLoadExtAction(ExtType, LD->getValueType(0), LoadedVT);
+      Action = TLI.getLoadExtAction(ExtType, LD->getValueType(0), LoadedVT,
+                                    LD->getAddressSpace());
     break;
   }
   case ISD::STORE: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cc503d324e74b..501b9dd3294b7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12378,7 +12378,8 @@ SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
   if (ResultVT.bitsGT(VecEltVT)) {
     // If the result type of vextract is wider than the load, then issue an
     // extending load instead.
-    ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
+    ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT,
+                                              OriginalLoad->getAddressSpace())
                                    ? ISD::ZEXTLOAD
                                    : ISD::EXTLOAD;
     Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index c23281a820b2b..1343ffe1db70e 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -729,6 +729,7 @@ void TargetLoweringBase::initActions() {
   // All operations default to being supported.
   memset(OpActions, 0, sizeof(OpActions));
   memset(LoadExtActions, 0, sizeof(LoadExtActions));
+  LoadExtActionOverrides.clear();
   memset(TruncStoreActions, 0, sizeof(TruncStoreActions));
   memset(IndexedModeActions, 0, sizeof(IndexedModeActions));
   memset(CondCodeActions, 0, sizeof(CondCodeActions));
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dc8e7c84f5e2c..c7467d8ddd25c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6752,7 +6752,8 @@ bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
   // results in just one set of predicate unpacks at the start, instead of
   // multiple sets of vector unpacks after each load.
   if (auto *Ld = dyn_cast<MaskedLoadSDNode>(ExtVal->getOperand(0))) {
-    if (!isLoadExtLegalOrCustom(ISD::ZEXTLOAD, ExtVT, Ld->getValueType(0))) {
+    if (!isLoadExtLegalOrCustom(ISD::ZEXTLOAD, ExtVT, Ld->getValueType(0),
+                                Ld->getAddressSpace())) {
       // Disable extending masked loads for fixed-width for now, since the code
       // quality doesn't look great.
       if (!ExtVT.isScalableVector())
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 931a10b700c87..4a6192cba9345 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -60344,7 +60344,7 @@ static SDValue combineEXTEND_VECTOR_INREG(SDNode *N, SelectionDAG &DAG,
                               ? ISD::SEXTLOAD
                               : ISD::ZEXTLOAD;
     EVT MemVT = VT.changeVectorElementType(SVT);
-    if (TLI.isLoadExtLegal(Ext, VT, MemVT)) {
+    if (TLI.isLoadExtLegal(Ext, VT, MemVT, Ld->getAddressSpace())) {
       SDValue Load = DAG.getExtLoad(
           Ext, DL, VT, Ld->getChain(), Ld->getBasePtr(), Ld->getPointerInfo(),
           MemVT, Ld->getBaseAlign(), Ld->getMemOperand()->getFlags());
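
From the target side, usage would look like the following sketch (hypothetical target and made-up address-space number, not taken from the patch). Omitting `AddrSpace` keeps the old behavior and writes the generic `LoadExtActions` table; passing one records a per-address-space override that `getLoadExtAction` consults first:

```cpp
// Inside a hypothetical MyTargetLowering constructor:
// zext i8 -> i32 loads stay Legal everywhere except address space 3, where
// this imaginary target has no extending-load instruction.
setLoadExtAction(ISD::ZEXTLOAD, MVT::i32, MVT::i8, Legal);
setLoadExtAction(ISD::ZEXTLOAD, MVT::i32, MVT::i8, Expand, /*AddrSpace=*/3);

// Queries then diverge per address space:
//   isLoadExtLegal(ISD::ZEXTLOAD, MVT::i32, MVT::i8, /*AddrSpace=*/0) -> true
//   isLoadExtLegal(ISD::ZEXTLOAD, MVT::i32, MVT::i8, /*AddrSpace=*/3) -> false
```

One caveat implied by the implementation: once any extension type is overridden for a (ValVT, MemVT, AddrSpace) triple, the override's remaining 4-bit fields default to 0 (Legal) and shadow the generic table for that triple, so a target would need to set every ext type it cares about for that address space.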