From 99e87a3843aef19f4bf3f3f3c35726360f99ef99 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Tue, 7 Oct 2025 17:10:47 -0700 Subject: [PATCH 1/4] Initial conversion of LoadExtActions to map --- llvm/include/llvm/CodeGen/TargetLowering.h | 46 ++++++++++++++-------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 88691b931a8d8..95c384d09d2b5 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -1472,27 +1472,34 @@ class LLVM_ABI TargetLoweringBase { /// Return how this load with extension should be treated: either it is legal, /// needs to be promoted to a larger size, needs to be expanded to some other /// code sequence, or the target has a custom expander for it. - LegalizeAction getLoadExtAction(unsigned ExtType, EVT ValVT, - EVT MemVT) const { - if (ValVT.isExtended() || MemVT.isExtended()) return Expand; - unsigned ValI = (unsigned) ValVT.getSimpleVT().SimpleTy; - unsigned MemI = (unsigned) MemVT.getSimpleVT().SimpleTy; + LegalizeAction getLoadExtAction(unsigned ExtType, EVT ValVT, EVT MemVT, + unsigned AddrSpace) const { + if (ValVT.isExtended() || MemVT.isExtended()) + return Expand; + unsigned ValI = (unsigned)ValVT.getSimpleVT().SimpleTy; + unsigned MemI = (unsigned)MemVT.getSimpleVT().SimpleTy; assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValI < MVT::VALUETYPE_SIZE && MemI < MVT::VALUETYPE_SIZE && "Table isn't big enough!"); unsigned Shift = 4 * ExtType; - return (LegalizeAction)((LoadExtActions[ValI][MemI] >> Shift) & 0xf); + + if (!LoadExtActions.count(AddrSpace)) { + return Legal; // default + } + return ( + LegalizeAction)((LoadExtActions.at(AddrSpace)[ValI][MemI] >> Shift) & + 0xf); } /// Return true if the specified load with extension is legal on this target. - bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT) const { - return getLoadExtAction(ExtType, ValVT, MemVT) == Legal; + bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT, unsigned AddrSpace) const { + return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal; } /// Return true if the specified load with extension is legal or custom /// on this target. - bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT) const { - return getLoadExtAction(ExtType, ValVT, MemVT) == Legal || - getLoadExtAction(ExtType, ValVT, MemVT) == Custom; + bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT, unsigned AddrSpace) const { + return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal || + getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Custom; } /// Same as getLoadExtAction, but for atomic loads. @@ -2634,13 +2641,15 @@ class LLVM_ABI TargetLoweringBase { /// Indicate that the specified load with extension does not work with the /// specified type and indicate what to do about it. 
void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT, - LegalizeAction Action) { + LegalizeAction Action, unsigned AddrSpace) { assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValVT.isValid() && MemVT.isValid() && "Table isn't big enough!"); assert((unsigned)Action < 0x10 && "too many bits for bitfield array"); unsigned Shift = 4 * ExtType; - LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] &= ~((uint16_t)0xF << Shift); - LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] |= (uint16_t)Action << Shift; + LoadExtActions[AddrSpace][ValVT.SimpleTy][MemVT.SimpleTy] &= + ~((uint16_t)0xF << Shift); + LoadExtActions[AddrSpace][ValVT.SimpleTy][MemVT.SimpleTy] |= + (uint16_t)Action << Shift; } void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT, MVT MemVT, LegalizeAction Action) { @@ -3753,8 +3762,13 @@ class LLVM_ABI TargetLoweringBase { /// For each load extension type and each value type, keep a LegalizeAction /// that indicates how instruction selection should deal with a load of a /// specific value type and extension type. Uses 4-bits to store the action - /// for each of the 4 load ext types. - uint16_t LoadExtActions[MVT::VALUETYPE_SIZE][MVT::VALUETYPE_SIZE]; + /// for each of the 4 load ext types. These actions can be specified for each + /// address space. + using LoadExtActionMapTy = + std::array<std::array<uint16_t, MVT::VALUETYPE_SIZE>, + MVT::VALUETYPE_SIZE>; + using LoadExtActionMap = std::map<unsigned, LoadExtActionMapTy>; + LoadExtActionMap LoadExtActions; /// Similar to LoadExtActions, but for atomic loads. Only Legal or Expand /// (default) values are supported. From ebcf9ba3a60dbc7475e3eaf132d53282d3a304b7 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Wed, 8 Oct 2025 10:09:29 -0700 Subject: [PATCH 2/4] Finish making address-space sensitive, and fix AMDGPU backend --- llvm/include/llvm/CodeGen/BasicTTIImpl.h | 14 +- llvm/include/llvm/CodeGen/TargetLowering.h | 31 +++- llvm/lib/CodeGen/CodeGenPrepare.cpp | 2 +- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 98 ++++++++----- llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 15 +- .../SelectionDAG/LegalizeVectorOps.cpp | 2 +- .../CodeGen/SelectionDAG/TargetLowering.cpp | 3 +- llvm/lib/CodeGen/TargetLoweringBase.cpp | 4 +- .../Target/AArch64/AArch64ISelLowering.cpp | 3 +- llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 134 ++++++++++-------- llvm/lib/Target/AMDGPU/R600ISelLowering.cpp | 42 ++++-- llvm/lib/Target/X86/X86ISelLowering.cpp | 2 +- 12 files changed, 219 insertions(+), 131 deletions(-) diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index 42ddb32d24093..af087b154c7f7 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -28,6 +28,7 @@ #include "llvm/Analysis/TargetTransformInfoImpl.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/CodeGen/ISDOpcodes.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/CodeGen/ValueTypes.h" @@ -1244,9 +1245,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> { EVT LoadVT = EVT::getEVT(Src); unsigned LType = ((Opcode == Instruction::ZExt) ?
ISD::ZEXTLOAD : ISD::SEXTLOAD); - if (DstLT.first == SrcLT.first && - TLI->isLoadExtLegal(LType, ExtVT, LoadVT)) - return 0; + + if (I && isa<LoadInst>(I->getOperand(0))) { + auto *LI = cast<LoadInst>(I->getOperand(0)); + + if (DstLT.first == SrcLT.first && + TLI->isLoadExtLegal(LType, ExtVT, LoadVT, LI->getPointerAddressSpace())) + return 0; + } } break; case Instruction::AddrSpaceCast: @@ -1531,7 +1537,7 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> { if (Opcode == Instruction::Store) LA = getTLI()->getTruncStoreAction(LT.second, MemVT); else - LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT); + LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT, AddressSpace); if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) { // This is a vector load/store for some illegal type that is scalarized. diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 95c384d09d2b5..a5af81aadc33f 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -1483,7 +1483,11 @@ class LLVM_ABI TargetLoweringBase { unsigned Shift = 4 * ExtType; if (!LoadExtActions.count(AddrSpace)) { - return Legal; // default + if (MemVT == MVT::i2 || MemVT == MVT::i4 || MemVT == MVT::v128i2 || + MemVT == MVT::v64i4) + return Expand; + + return Legal; } return ( LegalizeAction)((LoadExtActions.at(AddrSpace)[ValI][MemI] >> Shift) & 0xf); } @@ -2641,10 +2645,22 @@ class LLVM_ABI TargetLoweringBase { /// Indicate that the specified load with extension does not work with the /// specified type and indicate what to do about it. void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT, - LegalizeAction Action, unsigned AddrSpace) { + LegalizeAction Action, unsigned AddrSpace = 0) { assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValVT.isValid() && MemVT.isValid() && "Table isn't big enough!"); assert((unsigned)Action < 0x10 && "too many bits for bitfield array"); + + if (!LoadExtActions.count(AddrSpace)) { + LoadExtActions[AddrSpace]; // Initialize the map for the addrspace + + for (MVT AVT : MVT::all_valuetypes()) { + for (MVT VT : {MVT::i2, MVT::i4, MVT::v128i2, MVT::v64i4}) { + setLoadExtAction(ISD::EXTLOAD, AVT, VT, Expand, AddrSpace); + setLoadExtAction(ISD::ZEXTLOAD, AVT, VT, Expand, AddrSpace); + } + } + } + unsigned Shift = 4 * ExtType; LoadExtActions[AddrSpace][ValVT.SimpleTy][MemVT.SimpleTy] &= ~((uint16_t)0xF << Shift); @@ -2652,14 +2668,15 @@ class LLVM_ABI TargetLoweringBase { (uint16_t)Action << Shift; } void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT, MVT MemVT, - LegalizeAction Action) { + LegalizeAction Action, unsigned AddrSpace = 0) { for (auto ExtType : ExtTypes) - setLoadExtAction(ExtType, ValVT, MemVT, Action); + setLoadExtAction(ExtType, ValVT, MemVT, Action, AddrSpace); } void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT, - ArrayRef<MVT> MemVTs, LegalizeAction Action) { + ArrayRef<MVT> MemVTs, LegalizeAction Action, + unsigned AddrSpace = 0) { for (auto MemVT : MemVTs) - setLoadExtAction(ExtTypes, ValVT, MemVT, Action); + setLoadExtAction(ExtTypes, ValVT, MemVT, Action, AddrSpace); } /// Let target indicate that an extending atomic load of the specified type @@ -3135,7 +3152,7 @@ class LLVM_ABI TargetLoweringBase { LType = ISD::SEXTLOAD; } - return isLoadExtLegal(LType, VT, LoadVT); + return isLoadExtLegal(LType, VT, LoadVT, Load->getPointerAddressSpace()); } /// Return true if any actual instruction that defines a value of type FromTy diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp
b/llvm/lib/CodeGen/CodeGenPrepare.cpp index eb73d01b3558c..1bcbc64f3105b 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -7347,7 +7347,7 @@ bool CodeGenPrepare::optimizeLoadExt(LoadInst *Load) { // Reject cases that won't be matched as extloads. if (!LoadResultVT.bitsGT(TruncVT) || !TruncVT.isRound() || - !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT)) + !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT, Load->getPointerAddressSpace())) return false; IRBuilder<> Builder(Load->getNextNode()); diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 309f1bea8b77c..b0519302adc34 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -6889,7 +6889,7 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN, if (ExtVT == LoadedVT && (!LegalOperations || - TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) { + TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT, LoadN->getAddressSpace()))) { // ZEXTLOAD will match without needing to change the size of the value being // loaded. return true; @@ -6905,7 +6905,7 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN, return false; if (LegalOperations && - !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT)) + !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT, LoadN->getAddressSpace())) return false; if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT, /*ByteOffset=*/0)) @@ -6968,7 +6968,7 @@ bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST, return false; if (LegalOperations && - !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT)) + !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT, Load->getAddressSpace())) return false; // For the transform to be legal, the load must produce only two values @@ -7480,7 +7480,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) { EVT LoadVT = MLoad->getMemoryVT(); EVT ExtVT = VT; - if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) { + if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT, MLoad->getAddressSpace())) { // For this AND to be a zero extension of the masked load the elements // of the BuildVec must mask the bottom bits of the extended element // type @@ -7633,7 +7633,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // optimisation. bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD, Load->getValueType(0), - Load->getMemoryVT()); + Load->getMemoryVT(), + Load->getAddressSpace()); // Resize the constant to the same size as the original memory access before // extension. If it is still the AllOnesValue then this AND is completely @@ -7825,7 +7826,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize); if (DAG.MaskedValueIsZero(N1, ExtBits) && ((!LegalOperations && LN0->isSimple()) || - TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) { + TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, LN0->getAddressSpace()))) { SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(), LN0->getBasePtr(), MemVT, LN0->getMemOperand()); @@ -9747,10 +9748,13 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) { // Before legalize we can introduce too wide illegal loads which will be later // split into legal sized loads. 
This enables us to combine i64 load by i8 // patterns to a couple of i32 loads on 32 bit targets. - if (LegalOperations && - !TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT, - MemVT)) - return SDValue(); + if (LegalOperations) { + for (auto *L : Loads) { + if (!TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT, + MemVT, L->getAddressSpace())) + return SDValue(); + } + } // Check if the bytes of the OR we are looking at match with either big or // little endian value load @@ -13425,9 +13429,11 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { unsigned WideWidth = WideVT.getScalarSizeInBits(); bool IsSigned = isSignedIntSetCC(CC); auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD; - if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() && - SetCCWidth != 1 && SetCCWidth < WideWidth && - TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) && + if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() && SetCCWidth != 1 && + SetCCWidth < WideWidth && + TLI.isLoadExtLegalOrCustom( + LoadExtOpcode, WideVT, NarrowVT, + cast<LoadSDNode>(LHS)->getAddressSpace()) && TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) { // Both compare operands can be widened for free. The LHS can use an // extended load, and the RHS is a constant: @@ -13874,8 +13880,10 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI, // Combine2), so we should conservatively check the OperationAction. LoadSDNode *Load1 = cast<LoadSDNode>(Op1); LoadSDNode *Load2 = cast<LoadSDNode>(Op2); - if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) || - !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) || + if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT(), + Load1->getAddressSpace()) || + !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT(), + Load2->getAddressSpace()) || (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes && TLI.getOperationAction(ISD::VSELECT, VT) != TargetLowering::Legal)) return SDValue(); @@ -14099,13 +14107,15 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) { // Try to split the vector types to get down to legal types.
EVT SplitSrcVT = SrcVT; EVT SplitDstVT = DstVT; - while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) && + while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT, + LN0->getAddressSpace()) && SplitSrcVT.getVectorNumElements() > 1) { SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first; SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first; } - if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT)) + if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT, + LN0->getAddressSpace())) return SDValue(); assert(!DstVT.isScalableVector() && "Unexpected scalable vector type"); @@ -14178,7 +14188,7 @@ SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) { return SDValue(); LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0)); EVT MemVT = Load->getMemoryVT(); - if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) || + if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, Load->getAddressSpace()) || Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed()) return SDValue(); @@ -14288,7 +14298,7 @@ static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner, EVT MemVT = LN0->getMemoryVT(); if ((LegalOperations || !LN0->isSimple() || VT.isVector()) && - !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT)) + !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT, LN0->getAddressSpace())) return SDValue(); SDValue ExtLoad = @@ -14330,12 +14340,13 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner, } } + LoadSDNode *LN0 = cast<LoadSDNode>(N0); // TODO: isFixedLengthVector() should be removed and any negative effects on // code generation being the result of that target's implementation of // isVectorLoadExtDesirable(). if ((LegalOperations || VT.isFixedLengthVector() || - !cast<LoadSDNode>(N0)->isSimple()) && - !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType())) + !LN0->isSimple()) && + !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(), LN0->getAddressSpace())) return {}; bool DoXform = true; @@ -14347,7 +14358,6 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner, if (!DoXform) return {}; - LoadSDNode *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(), LN0->getBasePtr(), N0.getValueType(), LN0->getMemOperand()); @@ -14377,8 +14387,8 @@ tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT, if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD) return SDValue(); - if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) && - !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0))) + if ((LegalOperations || !Ld->isSimple()) && + !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0), Ld->getAddressSpace())) return SDValue(); if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0))) @@ -14522,7 +14532,8 @@ SDValue DAGCombiner::foldSextSetcc(SDNode *N) { if (!(ISD::isNON_EXTLoad(V.getNode()) && ISD::isUNINDEXEDLoad(V.getNode()) && cast<LoadSDNode>(V)->isSimple() && - TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType()))) + TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType(), + cast<LoadSDNode>(V)->getAddressSpace()))) return false; // Non-chain users of this value must either be the setcc in this @@ -14719,7 +14730,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) { LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0)); EVT MemVT = LN00->getMemoryVT(); - if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) && + if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT, LN00->getAddressSpace()) && LN00->getExtensionType() != ISD::ZEXTLOAD &&
LN00->isUnindexed()) { SmallVector<SDNode *, 4> SetCCs; bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0), @@ -15037,7 +15048,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) { LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0)); EVT MemVT = LN00->getMemoryVT(); - if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) && + if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, LN00->getAddressSpace()) && LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) { bool DoXform = true; SmallVector<SDNode *, 4> SetCCs; @@ -15268,7 +15279,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { return foldedExt; } else if (ISD::isNON_EXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && - TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) { + TLI.isLoadExtLegalOrCustom( + ISD::EXTLOAD, VT, N0.getValueType(), + cast<LoadSDNode>(N0)->getAddressSpace())) { bool DoXform = true; SmallVector<SDNode *, 4> SetCCs; if (!N0.hasOneUse()) @@ -15303,7 +15316,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); ISD::LoadExtType ExtType = LN0->getExtensionType(); EVT MemVT = LN0->getMemoryVT(); - if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT)) { + if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT, LN0->getAddressSpace())) { SDValue ExtLoad = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), LN0->getBasePtr(), MemVT, LN0->getMemOperand()); @@ -15617,7 +15630,8 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) { EVT::getIntegerVT(*DAG.getContext(), ShiftMask.countr_one()); // If the mask is smaller, recompute the type. if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) && - TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT)) + TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT, + LN->getAddressSpace())) ExtVT = MaskedVT; } else if (ExtType == ISD::ZEXTLOAD && ShiftMask.isShiftedMask(Offset, ActiveBits) && @@ -15626,7 +15640,8 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) { // If the mask is shifted we can use a narrower load and a shl to insert // the trailing zeros.
if (((Offset + ActiveBits) <= ExtVT.getScalarSizeInBits()) && - TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT)) { + TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT, + LN->getAddressSpace())) { ExtVT = MaskedVT; ShAmt = Offset + ShAmt; ShiftedOffset = Offset; @@ -15852,7 +15867,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() && ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple() && N0.hasOneUse()) || - TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) { + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT, + cast<LoadSDNode>(N0)->getAddressSpace()))) { auto *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(), @@ -15867,7 +15883,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() && ((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) && - TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT))) { + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT, + cast<LoadSDNode>(N0)->getAddressSpace()))) { auto *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(), @@ -15882,7 +15899,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { if (MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0)) { if (ExtVT == Ld->getMemoryVT() && N0.hasOneUse() && Ld->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD && - TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) { + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT, Ld->getAddressSpace())) { SDValue ExtMaskedLoad = DAG.getMaskedLoad( VT, DL, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(), Ld->getPassThru(), ExtVT, Ld->getMemOperand(), @@ -19222,7 +19239,8 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) { // fold (fpext (load x)) -> (fpext (fptrunc (extload x))) if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() && - TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) { + TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType(), + cast<LoadSDNode>(N0)->getAddressSpace())) { LoadSDNode *LN0 = cast<LoadSDNode>(N0); SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, VT, LN0->getChain(), @@ -22274,12 +22292,16 @@ bool DAGCombiner::tryStoreMergeOfLoads(SmallVectorImpl<MemOpLink> &StoreNodes, } else if (TLI.getTypeAction(Context, StoreTy) == TargetLowering::TypePromoteInteger) { EVT LegalizedStoredValTy = TLI.getTypeToTransformTo(Context, StoreTy); + unsigned AS = LoadNodes[i].MemNode->getAddressSpace(); if (TLI.isTruncStoreLegal(LegalizedStoredValTy, StoreTy) && TLI.canMergeStoresTo(FirstStoreAS, LegalizedStoredValTy, DAG.getMachineFunction()) && - TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy) && - TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy) && - TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy) && + TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValTy, StoreTy, + AS) && + TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValTy, StoreTy, + AS) && + TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValTy, StoreTy, + AS) && TLI.allowsMemoryAccess(Context, DL, StoreTy, *FirstInChain->getMemOperand(), &IsFastSt) && IsFastSt && diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp index 5fb7e63cfb605..99dcf23e9b121 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -348,7 +348,7 @@ SelectionDAGLegalize::ExpandConstantFP(ConstantFPSDNode *CFP, bool UseCP) { if (ConstantFPSDNode::isValueValidForType(SVT, APF) && //
Only do this if the target has a native EXTLOAD instruction from // smaller type. - TLI.isLoadExtLegal(ISD::EXTLOAD, OrigVT, SVT) && + TLI.isLoadExtLegal(ISD::EXTLOAD, OrigVT, SVT, 0) && TLI.ShouldShrinkFPConstant(OrigVT)) { Type *SType = SVT.getTypeForEVT(*DAG.getContext()); LLVMC = cast<ConstantFP>(ConstantFoldCastOperand( @@ -740,7 +740,8 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) { // nice to have an effective generic way of getting these benefits... // Until such a way is found, don't insist on promoting i1 here. (SrcVT != MVT::i1 || - TLI.getLoadExtAction(ExtType, Node->getValueType(0), MVT::i1) == + TLI.getLoadExtAction(ExtType, Node->getValueType(0), MVT::i1, + LD->getAddressSpace()) == TargetLowering::Promote)) { // Promote to a byte-sized load if not loading an integral number of // bytes. For example, promote EXTLOAD:i20 -> EXTLOAD:i24. @@ -852,7 +853,7 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) { } else { bool isCustom = false; switch (TLI.getLoadExtAction(ExtType, Node->getValueType(0), - SrcVT.getSimpleVT())) { + SrcVT.getSimpleVT(), LD->getAddressSpace())) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Custom: isCustom = true; @@ -880,13 +881,15 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) { case TargetLowering::Expand: { EVT DestVT = Node->getValueType(0); - if (!TLI.isLoadExtLegal(ISD::EXTLOAD, DestVT, SrcVT)) { + if (!TLI.isLoadExtLegal(ISD::EXTLOAD, DestVT, SrcVT, + LD->getAddressSpace())) { // If the source type is not legal, see if there is a legal extload to // an intermediate type that we can then extend further. EVT LoadVT = TLI.getRegisterType(SrcVT.getSimpleVT()); if ((LoadVT.isFloatingPoint() == SrcVT.isFloatingPoint()) && (TLI.isTypeLegal(SrcVT) || // Same as SrcVT == LoadVT? - TLI.isLoadExtLegal(ExtType, LoadVT, SrcVT))) { + TLI.isLoadExtLegal(ExtType, LoadVT, SrcVT, + LD->getAddressSpace()))) { // If we are loading a legal type, this is a non-extload followed by a // full extend. ISD::LoadExtType MidExtType = @@ -1846,7 +1849,7 @@ SDValue SelectionDAGLegalize::EmitStackConvert(SDValue SrcOp, EVT SlotVT, if ((SrcVT.bitsGT(SlotVT) && !TLI.isTruncStoreLegalOrCustom(SrcOp.getValueType(), SlotVT)) || (SlotVT.bitsLT(DestVT) && - !TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT))) + !TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT, DAG.getDataLayout().getAllocaAddrSpace()))) return SDValue(); // Create the stack frame object.
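The bit-packing that getLoadExtAction and setLoadExtAction manipulate in the hunks above can be modeled in isolation. The following self-contained sketch is not part of the patch series; PackedCell is an invented stand-in and LegalizeAction is trimmed to the handful of values used here. It shows how one uint16_t cell stores an action for each of the four load extension types, using the same Shift = 4 * ExtType arithmetic as the patch:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Trimmed stand-in for TargetLowering::LegalizeAction; Legal must be 0 so a
// zero-initialized cell means "everything legal", as in initActions().
enum LegalizeAction : uint8_t { Legal = 0, Promote, Expand, LibCall, Custom };

// One table cell: four 4-bit action slots indexed by ext type
// (NON_EXTLOAD = 0, EXTLOAD = 1, SEXTLOAD = 2, ZEXTLOAD = 3).
struct PackedCell {
  uint16_t Bits = 0;

  void set(unsigned ExtType, LegalizeAction Action) {
    assert(ExtType < 4 && (unsigned)Action < 0x10 && "out of range");
    unsigned Shift = 4 * ExtType;       // Same shift as the patch.
    Bits &= ~((uint16_t)0xF << Shift);  // Clear the old 4-bit slot.
    Bits |= (uint16_t)Action << Shift;  // Write the new action.
  }
  LegalizeAction get(unsigned ExtType) const {
    return (LegalizeAction)((Bits >> (4 * ExtType)) & 0xF);
  }
};

int main() {
  PackedCell Cell;               // Zero-initialized: all four slots Legal.
  Cell.set(2, Expand);           // SEXTLOAD -> Expand
  Cell.set(3, Promote);          // ZEXTLOAD -> Promote
  assert(Cell.get(1) == Legal);  // EXTLOAD slot is untouched.
  std::printf("SEXTLOAD=%d ZEXTLOAD=%d\n", Cell.get(2), Cell.get(3));
  return 0;
}

Patch 1 keys a full ValVT x MemVT table of such cells by address space; patch 3 below replaces that with a sparse per-(type, type, address space) override map over a dense table.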
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 8e423c4f83b38..a25705235cb40 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -301,7 +301,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { ISD::LoadExtType ExtType = LD->getExtensionType(); EVT LoadedVT = LD->getMemoryVT(); if (LoadedVT.isVector() && ExtType != ISD::NON_EXTLOAD) - Action = TLI.getLoadExtAction(ExtType, LD->getValueType(0), LoadedVT); + Action = TLI.getLoadExtAction(ExtType, LD->getValueType(0), LoadedVT, LD->getAddressSpace()); break; } case ISD::STORE: { diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index cc503d324e74b..501b9dd3294b7 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -12378,7 +12378,8 @@ SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT, if (ResultVT.bitsGT(VecEltVT)) { // If the result type of vextract is wider than the load, then issue an // extending load instead. - ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) + ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT, + OriginalLoad->getAddressSpace()) ? ISD::ZEXTLOAD : ISD::EXTLOAD; Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(), diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp index c23281a820b2b..3addb58c06f8f 100644 --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -728,7 +728,7 @@ TargetLoweringBase::~TargetLoweringBase() = default; void TargetLoweringBase::initActions() { // All operations default to being supported. memset(OpActions, 0, sizeof(OpActions)); - memset(LoadExtActions, 0, sizeof(LoadExtActions)); + LoadExtActions.clear(); memset(TruncStoreActions, 0, sizeof(TruncStoreActions)); memset(IndexedModeActions, 0, sizeof(IndexedModeActions)); memset(CondCodeActions, 0, sizeof(CondCodeActions)); @@ -751,8 +751,6 @@ void TargetLoweringBase::initActions() { for (MVT AVT : MVT::all_valuetypes()) { for (MVT VT : {MVT::i2, MVT::i4, MVT::v128i2, MVT::v64i4}) { setTruncStoreAction(AVT, VT, Expand); - setLoadExtAction(ISD::EXTLOAD, AVT, VT, Expand); - setLoadExtAction(ISD::ZEXTLOAD, AVT, VT, Expand); } } for (unsigned IM = (unsigned)ISD::PRE_INC; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index dc8e7c84f5e2c..c7467d8ddd25c 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -6752,7 +6752,8 @@ bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { // results in just one set of predicate unpacks at the start, instead of // multiple sets of vector unpacks after each load. if (auto *Ld = dyn_cast<MaskedLoadSDNode>(ExtVal->getOperand(0))) { - if (!isLoadExtLegalOrCustom(ISD::ZEXTLOAD, ExtVT, Ld->getValueType(0))) { + if (!isLoadExtLegalOrCustom(ISD::ZEXTLOAD, ExtVT, Ld->getValueType(0), + Ld->getAddressSpace())) { // Disable extending masked loads for fixed-width for now, since the code // quality doesn't look great.
if (!ExtVT.isScalableVector()) diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index a44af5f854c18..3aa8e4602b497 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -23,6 +23,7 @@ #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/Support/AMDGPUAddrSpace.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" #include "llvm/Target/TargetMachine.h" @@ -178,64 +179,85 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, setOperationAction(ISD::ATOMIC_STORE, MVT::bf16, Promote); AddPromotedToType(ISD::ATOMIC_STORE, MVT::bf16, MVT::i16); - // There are no 64-bit extloads. These should be done as a 32-bit extload and - // an extension to 64-bit. - for (MVT VT : MVT::integer_valuetypes()) - setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i64, VT, - Expand); - - for (MVT VT : MVT::integer_valuetypes()) { - if (VT == MVT::i64) - continue; - - for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) { - setLoadExtAction(Op, VT, MVT::i1, Promote); - setLoadExtAction(Op, VT, MVT::i8, Legal); - setLoadExtAction(Op, VT, MVT::i16, Legal); - setLoadExtAction(Op, VT, MVT::i32, Expand); + for (unsigned AddrSpace : { + AMDGPUAS::MAX_AMDGPU_ADDRESS, AMDGPUAS::FLAT_ADDRESS, + AMDGPUAS::GLOBAL_ADDRESS, AMDGPUAS::REGION_ADDRESS, + AMDGPUAS::LOCAL_ADDRESS, AMDGPUAS::CONSTANT_ADDRESS, + AMDGPUAS::PRIVATE_ADDRESS, AMDGPUAS::CONSTANT_ADDRESS_32BIT, + AMDGPUAS::BUFFER_FAT_POINTER, AMDGPUAS::BUFFER_RESOURCE, + AMDGPUAS::BUFFER_STRIDED_POINTER, AMDGPUAS::STREAMOUT_REGISTER, + AMDGPUAS::PARAM_D_ADDRESS, AMDGPUAS::PARAM_I_ADDRESS, + + AMDGPUAS::CONSTANT_BUFFER_0, AMDGPUAS::CONSTANT_BUFFER_1, + AMDGPUAS::CONSTANT_BUFFER_2, AMDGPUAS::CONSTANT_BUFFER_3, + AMDGPUAS::CONSTANT_BUFFER_4, AMDGPUAS::CONSTANT_BUFFER_5, + AMDGPUAS::CONSTANT_BUFFER_6, AMDGPUAS::CONSTANT_BUFFER_7, + AMDGPUAS::CONSTANT_BUFFER_8, AMDGPUAS::CONSTANT_BUFFER_9, + AMDGPUAS::CONSTANT_BUFFER_10, AMDGPUAS::CONSTANT_BUFFER_11, + AMDGPUAS::CONSTANT_BUFFER_12, AMDGPUAS::CONSTANT_BUFFER_13, + AMDGPUAS::CONSTANT_BUFFER_14, AMDGPUAS::CONSTANT_BUFFER_15, + AMDGPUAS::CONSTANT_BUFFER_15, + }) { // TODO: find easier way to iterate all (relevant) addrspaces + + // There are no 64-bit extloads. These should be done as a 32-bit extload + // and an extension to 64-bit.
+ for (MVT VT : MVT::integer_valuetypes()) + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i64, + VT, Expand, AddrSpace); + + for (MVT VT : MVT::integer_valuetypes()) { + if (VT == MVT::i64) + continue; + + for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) { + setLoadExtAction(Op, VT, MVT::i1, Promote, AddrSpace); + setLoadExtAction(Op, VT, MVT::i8, Legal, AddrSpace); + setLoadExtAction(Op, VT, MVT::i16, Legal, AddrSpace); + setLoadExtAction(Op, VT, MVT::i32, Expand, AddrSpace); + } } - } - - for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) - for (auto MemVT : - {MVT::v2i8, MVT::v4i8, MVT::v2i16, MVT::v3i16, MVT::v4i16}) - setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, MemVT, - Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32bf16, Expand); - - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f32, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f32, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f32, Expand); - - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f16, Expand); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16bf16, Expand); + for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) + for (auto MemVT : + {MVT::v2i8, MVT::v4i8, MVT::v2i16, MVT::v3i16, MVT::v4i16}) + setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, + MemVT, Expand, AddrSpace); + + setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, 
MVT::v3bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32bf16, Expand, AddrSpace); + + setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f32, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f32, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f32, Expand, AddrSpace); + + setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f16, Expand, AddrSpace); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16bf16, Expand, AddrSpace); + } setOperationAction(ISD::STORE, MVT::f32, Promote); AddPromotedToType(ISD::STORE, MVT::f32, MVT::i32); diff --git a/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp b/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp index 2aa54c920a046..2aa9cb24f17ff 100644 --- a/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp @@ -47,20 +47,38 @@ R600TargetLowering::R600TargetLowering(const TargetMachine &TM, // EXTLOAD should be the same as ZEXTLOAD. It is legal for some address // spaces, so it is custom lowered to handle those where it isn't. - for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) - for (MVT VT : MVT::integer_valuetypes()) { - setLoadExtAction(Op, VT, MVT::i1, Promote); - setLoadExtAction(Op, VT, MVT::i8, Custom); - setLoadExtAction(Op, VT, MVT::i16, Custom); - } - - // Workaround for LegalizeDAG asserting on expansion of i1 vector loads. 
- setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i32, - MVT::v2i1, Expand); + for (unsigned AddrSpace : { + AMDGPUAS::MAX_AMDGPU_ADDRESS, AMDGPUAS::FLAT_ADDRESS, + AMDGPUAS::GLOBAL_ADDRESS, AMDGPUAS::REGION_ADDRESS, + AMDGPUAS::LOCAL_ADDRESS, AMDGPUAS::CONSTANT_ADDRESS, + AMDGPUAS::PRIVATE_ADDRESS, AMDGPUAS::CONSTANT_ADDRESS_32BIT, + AMDGPUAS::BUFFER_FAT_POINTER, AMDGPUAS::BUFFER_RESOURCE, + AMDGPUAS::BUFFER_STRIDED_POINTER, AMDGPUAS::STREAMOUT_REGISTER, + AMDGPUAS::PARAM_D_ADDRESS, AMDGPUAS::PARAM_I_ADDRESS, + + AMDGPUAS::CONSTANT_BUFFER_0, AMDGPUAS::CONSTANT_BUFFER_1, + AMDGPUAS::CONSTANT_BUFFER_2, AMDGPUAS::CONSTANT_BUFFER_3, + AMDGPUAS::CONSTANT_BUFFER_4, AMDGPUAS::CONSTANT_BUFFER_5, + AMDGPUAS::CONSTANT_BUFFER_6, AMDGPUAS::CONSTANT_BUFFER_7, + AMDGPUAS::CONSTANT_BUFFER_8, AMDGPUAS::CONSTANT_BUFFER_9, + AMDGPUAS::CONSTANT_BUFFER_10, AMDGPUAS::CONSTANT_BUFFER_11, + AMDGPUAS::CONSTANT_BUFFER_12, AMDGPUAS::CONSTANT_BUFFER_13, + AMDGPUAS::CONSTANT_BUFFER_14, AMDGPUAS::CONSTANT_BUFFER_15, + }) { // TODO: find easier way to iterate all (relevant) addrspaces + for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) + for (MVT VT : MVT::integer_valuetypes()) { + setLoadExtAction(Op, VT, MVT::i1, Promote, AddrSpace); + setLoadExtAction(Op, VT, MVT::i8, Custom, AddrSpace); + setLoadExtAction(Op, VT, MVT::i16, Custom, AddrSpace); + } - setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v4i32, - MVT::v4i1, Expand); + // Workaround for LegalizeDAG asserting on expansion of i1 vector loads. + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i32, + MVT::v2i1, Expand, AddrSpace); + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v4i32, + MVT::v4i1, Expand, AddrSpace); + } setOperationAction(ISD::STORE, {MVT::i8, MVT::i32, MVT::v2i32, MVT::v4i32}, Custom); diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 931a10b700c87..4a6192cba9345 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -60344,7 +60344,7 @@ static SDValue combineEXTEND_VECTOR_INREG(SDNode *N, SelectionDAG &DAG, ?
ISD::SEXTLOAD : ISD::ZEXTLOAD; EVT MemVT = VT.changeVectorElementType(SVT); - if (TLI.isLoadExtLegal(Ext, VT, MemVT)) { + if (TLI.isLoadExtLegal(Ext, VT, MemVT, Ld->getAddressSpace())) { SDValue Load = DAG.getExtLoad( Ext, DL, VT, Ld->getChain(), Ld->getBasePtr(), Ld->getPointerInfo(), MemVT, Ld->getBaseAlign(), Ld->getMemOperand()->getFlags()); From 509137202adb2df286b33d7c20a4187ddf9323b4 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Wed, 8 Oct 2025 11:04:16 -0700 Subject: [PATCH 3/4] Make per address-space actions a sparse map overriding the original AS-independent ones instead --- llvm/include/llvm/CodeGen/TargetLowering.h | 56 ++++---- llvm/lib/CodeGen/TargetLoweringBase.cpp | 5 +- llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 134 ++++++++---------- llvm/lib/Target/AMDGPU/R600ISelLowering.cpp | 42 ++---- 4 files changed, 98 insertions(+), 139 deletions(-) diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index a5af81aadc33f..f923ce5c4510e 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -1482,16 +1482,14 @@ class LLVM_ABI TargetLoweringBase { MemI < MVT::VALUETYPE_SIZE && "Table isn't big enough!"); unsigned Shift = 4 * ExtType; - if (!LoadExtActions.count(AddrSpace)) { - if (MemVT == MVT::i2 || MemVT == MVT::i4 || MemVT == MVT::v128i2 || - MemVT == MVT::v64i4) - return Expand; + uint64_t OverrideKey = ((uint64_t)(ValI & 0xFF) << 40) | + ((uint64_t)(MemI & 0xFF) << 32) | + (uint64_t)AddrSpace; - return Legal; + if (LoadExtActionOverrides.count(OverrideKey)) { + return (LegalizeAction)((LoadExtActionOverrides.at(OverrideKey) >> Shift) & 0xf); } - return ( - LegalizeAction)((LoadExtActions.at(AddrSpace)[ValI][MemI] >> Shift) & - 0xf); + return (LegalizeAction)((LoadExtActions[ValI][MemI] >> Shift) & 0xf); } /// Return true if the specified load with extension is legal on this target. @@ -2645,36 +2643,36 @@ class LLVM_ABI TargetLoweringBase { /// Indicate that the specified load with extension does not work with the /// specified type and indicate what to do about it. 
void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT, - LegalizeAction Action, unsigned AddrSpace = 0) { + LegalizeAction Action, unsigned AddrSpace = ~0U) { assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValVT.isValid() && MemVT.isValid() && "Table isn't big enough!"); assert((unsigned)Action < 0x10 && "too many bits for bitfield array"); - if (!LoadExtActions.count(AddrSpace)) { - LoadExtActions[AddrSpace]; // Initialize the map for the addrspace + unsigned Shift = 4 * ExtType; - for (MVT AVT : MVT::all_valuetypes()) { - for (MVT VT : {MVT::i2, MVT::i4, MVT::v128i2, MVT::v64i4}) { - setLoadExtAction(ISD::EXTLOAD, AVT, VT, Expand, AddrSpace); - setLoadExtAction(ISD::ZEXTLOAD, AVT, VT, Expand, AddrSpace); - } - } + if (AddrSpace == ~0U) { + LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] &= + ~((uint16_t)0xF << Shift); + LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] |= (uint16_t)Action + << Shift; + } else { + uint64_t OverrideKey = ((uint64_t)(ValVT.SimpleTy & 0xFF) << 40) | + ((uint64_t)(MemVT.SimpleTy & 0xFF) << 32) | + (uint64_t)AddrSpace; + uint16_t &OverrideVal = LoadExtActionOverrides[OverrideKey]; + + OverrideVal &= ~((uint16_t)0xF << Shift); + OverrideVal |= (uint16_t)Action << Shift; } - - unsigned Shift = 4 * ExtType; - LoadExtActions[AddrSpace][ValVT.SimpleTy][MemVT.SimpleTy] &= - ~((uint16_t)0xF << Shift); - LoadExtActions[AddrSpace][ValVT.SimpleTy][MemVT.SimpleTy] |= - (uint16_t)Action << Shift; } void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT, MVT MemVT, - LegalizeAction Action, unsigned AddrSpace = 0) { + LegalizeAction Action, unsigned AddrSpace = ~0U) { for (auto ExtType : ExtTypes) setLoadExtAction(ExtType, ValVT, MemVT, Action, AddrSpace); } void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT, ArrayRef<MVT> MemVTs, LegalizeAction Action, - unsigned AddrSpace = 0) { + unsigned AddrSpace = ~0U) { for (auto MemVT : MemVTs) setLoadExtAction(ExtTypes, ValVT, MemVT, Action, AddrSpace); } @@ -3781,11 +3779,9 @@ class LLVM_ABI TargetLoweringBase { /// specific value type and extension type. Uses 4-bits to store the action /// for each of the 4 load ext types. These actions can be specified for each /// address space. - using LoadExtActionMapTy = - std::array<std::array<uint16_t, MVT::VALUETYPE_SIZE>, - MVT::VALUETYPE_SIZE>; - using LoadExtActionMap = std::map<unsigned, LoadExtActionMapTy>; - LoadExtActionMap LoadExtActions; + uint16_t LoadExtActions[MVT::VALUETYPE_SIZE][MVT::VALUETYPE_SIZE]; + using LoadExtActionOverrideMap = std::map<uint64_t, uint16_t>; + LoadExtActionOverrideMap LoadExtActionOverrides; /// Similar to LoadExtActions, but for atomic loads. Only Legal or Expand /// (default) values are supported. diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp index 3addb58c06f8f..1343ffe1db70e 100644 --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -728,7 +728,8 @@ TargetLoweringBase::~TargetLoweringBase() = default; void TargetLoweringBase::initActions() { // All operations default to being supported.
memset(OpActions, 0, sizeof(OpActions)); - LoadExtActions.clear(); + memset(LoadExtActions, 0, sizeof(LoadExtActions)); + LoadExtActionOverrides.clear(); memset(TruncStoreActions, 0, sizeof(TruncStoreActions)); memset(IndexedModeActions, 0, sizeof(IndexedModeActions)); memset(CondCodeActions, 0, sizeof(CondCodeActions)); @@ -751,6 +752,8 @@ void TargetLoweringBase::initActions() { for (MVT AVT : MVT::all_valuetypes()) { for (MVT VT : {MVT::i2, MVT::i4, MVT::v128i2, MVT::v64i4}) { setTruncStoreAction(AVT, VT, Expand); + setLoadExtAction(ISD::EXTLOAD, AVT, VT, Expand); + setLoadExtAction(ISD::ZEXTLOAD, AVT, VT, Expand); } } for (unsigned IM = (unsigned)ISD::PRE_INC; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index 3aa8e4602b497..a44af5f854c18 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -23,7 +23,6 @@ #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IntrinsicsAMDGPU.h" -#include "llvm/Support/AMDGPUAddrSpace.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" #include "llvm/Target/TargetMachine.h" @@ -179,86 +178,65 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, setOperationAction(ISD::ATOMIC_STORE, MVT::bf16, Promote); AddPromotedToType(ISD::ATOMIC_STORE, MVT::bf16, MVT::i16); - for (unsigned AddrSpace : { - AMDGPUAS::MAX_AMDGPU_ADDRESS, AMDGPUAS::FLAT_ADDRESS, - AMDGPUAS::GLOBAL_ADDRESS, AMDGPUAS::REGION_ADDRESS, - AMDGPUAS::LOCAL_ADDRESS, AMDGPUAS::CONSTANT_ADDRESS, - AMDGPUAS::PRIVATE_ADDRESS, AMDGPUAS::CONSTANT_ADDRESS_32BIT, - AMDGPUAS::BUFFER_FAT_POINTER, AMDGPUAS::BUFFER_RESOURCE, - AMDGPUAS::BUFFER_STRIDED_POINTER, AMDGPUAS::STREAMOUT_REGISTER, - AMDGPUAS::PARAM_D_ADDRESS, AMDGPUAS::PARAM_I_ADDRESS, - - AMDGPUAS::CONSTANT_BUFFER_0, AMDGPUAS::CONSTANT_BUFFER_1, - AMDGPUAS::CONSTANT_BUFFER_2, AMDGPUAS::CONSTANT_BUFFER_3, - AMDGPUAS::CONSTANT_BUFFER_4, AMDGPUAS::CONSTANT_BUFFER_5, - AMDGPUAS::CONSTANT_BUFFER_6, AMDGPUAS::CONSTANT_BUFFER_7, - AMDGPUAS::CONSTANT_BUFFER_8, AMDGPUAS::CONSTANT_BUFFER_9, - AMDGPUAS::CONSTANT_BUFFER_10, AMDGPUAS::CONSTANT_BUFFER_11, - AMDGPUAS::CONSTANT_BUFFER_12, AMDGPUAS::CONSTANT_BUFFER_13, - AMDGPUAS::CONSTANT_BUFFER_14, AMDGPUAS::CONSTANT_BUFFER_15, - AMDGPUAS::CONSTANT_BUFFER_15, - }) { // TODO: find easier way to iterate all (relevant) addrspaces - - // There are no 64-bit extloads. These should be done as a 32-bit extload - // and an extension to 64-bit. - for (MVT VT : MVT::integer_valuetypes()) - setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i64, - VT, Expand, AddrSpace); - - for (MVT VT : MVT::integer_valuetypes()) { - if (VT == MVT::i64) - continue; - - for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) { - setLoadExtAction(Op, VT, MVT::i1, Promote, AddrSpace); - setLoadExtAction(Op, VT, MVT::i8, Legal, AddrSpace); - setLoadExtAction(Op, VT, MVT::i16, Legal, AddrSpace); - setLoadExtAction(Op, VT, MVT::i32, Expand, AddrSpace); - } - } + // There are no 64-bit extloads. These should be done as a 32-bit extload and + // an extension to 64-bit.
+ for (MVT VT : MVT::integer_valuetypes()) + setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i64, VT, + Expand); + + for (MVT VT : MVT::integer_valuetypes()) { + if (VT == MVT::i64) + continue; - for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) - for (auto MemVT : - {MVT::v2i8, MVT::v4i8, MVT::v2i16, MVT::v3i16, MVT::v4i16}) - setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, - MemVT, Expand, AddrSpace); - - setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32bf16, Expand, AddrSpace); - - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f32, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f32, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f32, Expand, AddrSpace); - - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f16, Expand, AddrSpace); - setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16bf16, Expand, AddrSpace); + for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) { + setLoadExtAction(Op, VT, MVT::i1, Promote); + setLoadExtAction(Op, VT, MVT::i8, Legal); + setLoadExtAction(Op, VT, MVT::i16, Legal); + setLoadExtAction(Op, VT, MVT::i32, Expand); + } } + for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) + for (auto MemVT : + {MVT::v2i8, MVT::v4i8, MVT::v2i16, MVT::v3i16, MVT::v4i16}) + setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, MemVT, + Expand); + + setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand); + 
setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32bf16, Expand); + + setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f32, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f32, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f32, Expand); + + setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16bf16, Expand); + setOperationAction(ISD::STORE, MVT::f32, Promote); AddPromotedToType(ISD::STORE, MVT::f32, MVT::i32); diff --git a/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp b/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp index 2aa9cb24f17ff..2aa54c920a046 100644 --- a/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp @@ -47,38 +47,20 @@ R600TargetLowering::R600TargetLowering(const TargetMachine &TM, // EXTLOAD should be the same as ZEXTLOAD. It is legal for some address // spaces, so it is custom lowered to handle those where it isn't. 
- for (unsigned AddrSpace : { - AMDGPUAS::MAX_AMDGPU_ADDRESS, AMDGPUAS::FLAT_ADDRESS, - AMDGPUAS::GLOBAL_ADDRESS, AMDGPUAS::REGION_ADDRESS, - AMDGPUAS::LOCAL_ADDRESS, AMDGPUAS::CONSTANT_ADDRESS, - AMDGPUAS::PRIVATE_ADDRESS, AMDGPUAS::CONSTANT_ADDRESS_32BIT, - AMDGPUAS::BUFFER_FAT_POINTER, AMDGPUAS::BUFFER_RESOURCE, - AMDGPUAS::BUFFER_STRIDED_POINTER, AMDGPUAS::STREAMOUT_REGISTER, - AMDGPUAS::PARAM_D_ADDRESS, AMDGPUAS::PARAM_I_ADDRESS, - - AMDGPUAS::CONSTANT_BUFFER_0, AMDGPUAS::CONSTANT_BUFFER_1, - AMDGPUAS::CONSTANT_BUFFER_2, AMDGPUAS::CONSTANT_BUFFER_3, - AMDGPUAS::CONSTANT_BUFFER_4, AMDGPUAS::CONSTANT_BUFFER_5, - AMDGPUAS::CONSTANT_BUFFER_6, AMDGPUAS::CONSTANT_BUFFER_7, - AMDGPUAS::CONSTANT_BUFFER_8, AMDGPUAS::CONSTANT_BUFFER_9, - AMDGPUAS::CONSTANT_BUFFER_10, AMDGPUAS::CONSTANT_BUFFER_11, - AMDGPUAS::CONSTANT_BUFFER_12, AMDGPUAS::CONSTANT_BUFFER_13, - AMDGPUAS::CONSTANT_BUFFER_14, AMDGPUAS::CONSTANT_BUFFER_15, - }) { // TODO: find easier way to iterate all (relevant) addrspaces - for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) - for (MVT VT : MVT::integer_valuetypes()) { - setLoadExtAction(Op, VT, MVT::i1, Promote, AddrSpace); - setLoadExtAction(Op, VT, MVT::i8, Custom, AddrSpace); - setLoadExtAction(Op, VT, MVT::i16, Custom, AddrSpace); - } + for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) + for (MVT VT : MVT::integer_valuetypes()) { + setLoadExtAction(Op, VT, MVT::i1, Promote); + setLoadExtAction(Op, VT, MVT::i8, Custom); + setLoadExtAction(Op, VT, MVT::i16, Custom); + } - // Workaround for LegalizeDAG asserting on expansion of i1 vector loads.
+  setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i32,
+                   MVT::v2i1, Expand);
+
+  setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v4i32,
+                   MVT::v4i1, Expand);
-    setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v4i32,
-                     MVT::v4i1, Expand, AddrSpace);
-  }
 
   setOperationAction(ISD::STORE, {MVT::i8, MVT::i32, MVT::v2i32, MVT::v4i32},
                      Custom);
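The restored R600 calls above drop the explicit address-space argument, which only compiles if a form of setLoadExtAction without an address space still exists at this point in the series. A minimal sketch of what such a convenience overload could look like, assuming a wildcard/default key (DefaultAS below is hypothetical; the actual fallback mechanism is not visible in this excerpt):

  // Sketch only: register an action without naming an address space by
  // forwarding to the per-address-space API with a default key that lookups
  // fall back to when no specific entry exists.
  void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT,
                        LegalizeAction Action) {
    setLoadExtAction(ExtType, ValVT, MemVT, Action, /*AddrSpace=*/DefaultAS);
  }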
From 8f51e32118dbc44de7c06a8c28f202042b2d5931 Mon Sep 17 00:00:00 2001
From: Demetrius Kanios
Date: Wed, 8 Oct 2025 12:38:55 -0700
Subject: [PATCH 4/4] Touch up formatting

---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |  6 ++-
 llvm/include/llvm/CodeGen/TargetLowering.h    |  6 ++-
 llvm/lib/CodeGen/CodeGenPrepare.cpp           |  3 +-
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 38 ++++++++++---------
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp |  5 ++-
 .../SelectionDAG/LegalizeVectorOps.cpp        |  3 +-
 6 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index af087b154c7f7..1439683dc5e96 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1250,7 +1250,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       auto *LI = cast<LoadInst>(I->getOperand(0));
 
       if (DstLT.first == SrcLT.first &&
-          TLI->isLoadExtLegal(LType, ExtVT, LoadVT, LI->getPointerAddressSpace()))
+          TLI->isLoadExtLegal(LType, ExtVT, LoadVT,
+                              LI->getPointerAddressSpace()))
         return 0;
     }
   }
@@ -1537,7 +1538,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       if (Opcode == Instruction::Store)
         LA = getTLI()->getTruncStoreAction(LT.second, MemVT);
       else
-        LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT, AddressSpace);
+        LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT,
+                                        AddressSpace);
 
       if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) {
         // This is a vector load/store for some illegal type that is scalarized.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index f923ce5c4510e..0b160443ade59 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1493,13 +1493,15 @@ class LLVM_ABI TargetLoweringBase {
   }
 
   /// Return true if the specified load with extension is legal on this target.
-  bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT, unsigned AddrSpace) const {
+  bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT,
+                      unsigned AddrSpace) const {
     return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal;
   }
 
   /// Return true if the specified load with extension is legal or custom
   /// on this target.
-  bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT, unsigned AddrSpace) const {
+  bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT,
+                              unsigned AddrSpace) const {
     return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal ||
            getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Custom;
   }
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 1bcbc64f3105b..0b7ddf1211f54 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -7347,7 +7347,8 @@ bool CodeGenPrepare::optimizeLoadExt(LoadInst *Load) {
 
   // Reject cases that won't be matched as extloads.
   if (!LoadResultVT.bitsGT(TruncVT) || !TruncVT.isRound() ||
-      !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT, Load->getPointerAddressSpace()))
+      !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT,
+                           Load->getPointerAddressSpace()))
     return false;
 
   IRBuilder<> Builder(Load->getNextNode());
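Every hunk in this formatting patch touches the same mechanical change from earlier in the series: each extending-load legality query now threads the memory operation's address space through. The two sources of that address space, as used at the call sites above and below (illustrative composite, not a literal excerpt):

  // IR level (CodeGenPrepare, BasicTTIImpl): read it off the pointer operand.
  unsigned IRLevelAS = Load->getPointerAddressSpace();   // LoadInst *Load

  // DAG level (DAGCombiner, legalizers): read it off the memory node.
  unsigned DAGLevelAS = LN->getAddressSpace();           // LoadSDNode *LN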
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b0519302adc34..8a8de4ba97f92 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6889,7 +6889,8 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
   if (ExtVT == LoadedVT &&
       (!LegalOperations ||
-       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT, LoadN->getAddressSpace()))) {
+       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT,
+                          LoadN->getAddressSpace()))) {
     // ZEXTLOAD will match without needing to change the size of the value being
     // loaded.
     return true;
@@ -6904,8 +6905,8 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
     return false;
 
-  if (LegalOperations &&
-      !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT, LoadN->getAddressSpace()))
+  if (LegalOperations && !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT,
+                                             LoadN->getAddressSpace()))
     return false;
 
   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT, /*ByteOffset=*/0))
@@ -6967,8 +6968,8 @@ bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
   if (!SDValue(Load, 0).hasOneUse())
     return false;
 
-  if (LegalOperations &&
-      !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT, Load->getAddressSpace()))
+  if (LegalOperations && !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT,
+                                             Load->getAddressSpace()))
     return false;
 
   // For the transform to be legal, the load must produce only two values
@@ -7480,7 +7481,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
       EVT LoadVT = MLoad->getMemoryVT();
       EVT ExtVT = VT;
-      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT, MLoad->getAddressSpace())) {
+      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT,
+                             MLoad->getAddressSpace())) {
         // For this AND to be a zero extension of the masked load the elements
         // of the BuildVec must mask the bottom bits of the extended element
         // type
@@ -7631,10 +7633,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
     // actually legal and isn't going to get expanded, else this is a false
     // optimisation.
-    bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
-                                                    Load->getValueType(0),
-                                                    Load->getMemoryVT(),
-                                                    Load->getAddressSpace());
+    bool CanZextLoadProfitably =
+        TLI.isLoadExtLegal(ISD::ZEXTLOAD, Load->getValueType(0),
+                           Load->getMemoryVT(), Load->getAddressSpace());
 
     // Resize the constant to the same size as the original memory access before
     // extension. If it is still the AllOnesValue then this AND is completely
@@ -14296,8 +14297,7 @@ static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   EVT MemVT = LN0->getMemoryVT();
-  if ((LegalOperations || !LN0->isSimple() ||
-       VT.isVector()) &&
+  if ((LegalOperations || !LN0->isSimple() || VT.isVector()) &&
       !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT, LN0->getAddressSpace()))
     return SDValue();
 
@@ -14344,9 +14344,9 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
   // TODO: isFixedLengthVector() should be removed and any negative effects on
   // code generation being the result of that target's implementation of
   // isVectorLoadExtDesirable().
-  if ((LegalOperations || VT.isFixedLengthVector() ||
-       !LN0->isSimple()) &&
-      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(), LN0->getAddressSpace()))
+  if ((LegalOperations || VT.isFixedLengthVector() || !LN0->isSimple()) &&
+      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(),
+                          LN0->getAddressSpace()))
     return {};
 
   bool DoXform = true;
@@ -14388,7 +14388,8 @@ tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
     return SDValue();
 
   if ((LegalOperations || !Ld->isSimple()) &&
-      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0), Ld->getAddressSpace()))
+      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0),
+                                  Ld->getAddressSpace()))
     return SDValue();
 
   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
@@ -14731,7 +14732,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
       LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
       EVT MemVT = LN00->getMemoryVT();
       if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT, LN00->getAddressSpace()) &&
-           LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
+          LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
         SmallVector<SDNode *, 4> SetCCs;
         bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                                ISD::SIGN_EXTEND, SetCCs, TLI);
@@ -15316,7 +15317,8 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
     ISD::LoadExtType ExtType = LN0->getExtensionType();
     EVT MemVT = LN0->getMemoryVT();
-    if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT, LN0->getAddressSpace())) {
+    if (!LegalOperations ||
+        TLI.isLoadExtLegal(ExtType, VT, MemVT, LN0->getAddressSpace())) {
       SDValue ExtLoad = DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(),
                                        LN0->getBasePtr(), MemVT,
                                        LN0->getMemOperand());
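For orientation, the DAGCombiner guards above all protect variations of one fold: turning ext(load) into a single extending load. Reduced to its core (a simplified sketch of tryToFoldExtOfLoad above; one-use, chain, and profitability checks omitted):

  // Sketch of the guarded fold: (ext (load x)) -> (extload x), allowed only
  // when the target accepts this extload for the load's address space.
  if (!TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(),
                          LN0->getAddressSpace()))
    return SDValue();
  SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(N), VT, LN0->getChain(),
                                   LN0->getBasePtr(), N0.getValueType(),
                                   LN0->getMemOperand());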
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 99dcf23e9b121..532a8c490b481 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -742,7 +742,7 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
         (SrcVT != MVT::i1 ||
          TLI.getLoadExtAction(ExtType, Node->getValueType(0), MVT::i1,
                               LD->getAddressSpace()) ==
-              TargetLowering::Promote)) {
+             TargetLowering::Promote)) {
       // Promote to a byte-sized load if not loading an integral number of
       // bytes. For example, promote EXTLOAD:i20 -> EXTLOAD:i24.
      unsigned NewWidth = SrcVT.getStoreSizeInBits();
@@ -1849,7 +1849,8 @@ SDValue SelectionDAGLegalize::EmitStackConvert(SDValue SrcOp, EVT SlotVT,
   if ((SrcVT.bitsGT(SlotVT) &&
        !TLI.isTruncStoreLegalOrCustom(SrcOp.getValueType(), SlotVT)) ||
       (SlotVT.bitsLT(DestVT) &&
-       !TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT, DAG.getDataLayout().getAllocaAddrSpace())))
+       !TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT,
+                                   DAG.getDataLayout().getAllocaAddrSpace())))
     return SDValue();
 
   // Create the stack frame object.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index a25705235cb40..be8e780a6f55d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -301,7 +301,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
     ISD::LoadExtType ExtType = LD->getExtensionType();
     EVT LoadedVT = LD->getMemoryVT();
     if (LoadedVT.isVector() && ExtType != ISD::NON_EXTLOAD)
-      Action = TLI.getLoadExtAction(ExtType, LD->getValueType(0), LoadedVT, LD->getAddressSpace());
+      Action = TLI.getLoadExtAction(ExtType, LD->getValueType(0), LoadedVT,
+                                    LD->getAddressSpace());
     break;
   }
   case ISD::STORE: {
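One call site worth noting: EmitStackConvert queries legality for a load from a freshly created stack temporary, so there is no pre-existing memory node to take an address space from. The patch instead keys the query on the data layout's alloca address space, which is where stack slots live:

  // The stack slot created by EmitStackConvert lives in the alloca address
  // space, so the extload legality check is keyed on that address space.
  unsigned StackAS = DAG.getDataLayout().getAllocaAddrSpace();
  bool CanLoad =
      TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT, StackAS);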