Conversation

QuantumSegfault commented Oct 8, 2025

Reboot of #157627

Makes TargetLoweringBase::getLoadExtAction, TargetLoweringBase::setLoadExtAction, TargetLoweringBase::isLoadExtLegal, and TargetLoweringBase::isLoadExtLegalOrCustom dependent on the address space, rather than applying all-or-nothing logic across all address spaces.

github-actions bot commented Oct 8, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository; in that case, you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

QuantumSegfault changed the title from "[CodeGen] Convert LoadExtActions to a map indexed by AddrSpace" to "[CodeGen] Make LoadExtActions address-space aware" on Oct 8, 2025

QuantumSegfault commented Oct 8, 2025

@dschuff

Alright, so I'm close to finishing what @sparker-arm started. I made sure it doesn't introduce any regressions in the AMDGPU or WebAssembly backends. I imagine NVPTX will also need adjusting, but I haven't gotten around to compiling all the rest of LLVM yet.

However, I'm not quite satisfied with it as is. Currently there's a default parameter on setLoadExtAction to avoid having to modify every single backend, but it's more of a footgun than anything, since the default behavior has still changed.

I feel it makes more sense to remove the parameter, but that then touches virtually every backend, and requires everybody to learn the new behavior.

OR: we change tactics a little and make the map an optional override. We add a new LegalizeAction::Unspecified; new address spaces put into the override map are filled with Unspecified (I'm thinking 0xF?). During lookup, it checks the map first, then falls back to the old/existing LoadExtActions table if it sees Unspecified. setLoadExtAction would probably take a default AddrSpace of ~0, meaning "modify the fallback mapping shared by all address spaces".
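
Roughly what I have in mind, as a standalone toy model (just a sketch of the fallback idea, not this PR's code; the real tables are additionally indexed by ValVT/MemVT, and LLVM's LegalizeAction has more members than shown):

#include <cstdint>
#include <map>

// Simplified stand-in for LLVM's LegalizeAction; 0xF acts as the sentinel.
enum LegalizeAction : uint8_t { Legal = 0, Promote, Expand, Custom, Unspecified = 0xF };

struct LoadExtTable {
  uint16_t Shared = 0;                    // 4 bits per ext type, as today
  std::map<unsigned, uint16_t> Overrides; // AddrSpace -> packed actions

  void set(unsigned ExtType, LegalizeAction A, unsigned AddrSpace) {
    unsigned Shift = 4 * ExtType;
    // A newly seen address space starts with every slot Unspecified (0xFFFF).
    auto It = Overrides.try_emplace(AddrSpace, uint16_t(0xFFFF)).first;
    It->second = uint16_t((It->second & ~(0xF << Shift)) | (unsigned(A) << Shift));
  }

  LegalizeAction get(unsigned ExtType, unsigned AddrSpace) const {
    unsigned Shift = 4 * ExtType;
    auto It = Overrides.find(AddrSpace);
    if (It != Overrides.end()) {
      auto A = LegalizeAction((It->second >> Shift) & 0xF);
      if (A != Unspecified)
        return A; // a real per-address-space override wins
    }
    return LegalizeAction((Shared >> Shift) & 0xF); // fall back to shared table
  }
};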

At that rate, might as well make the map a sparse (AddrSpace, ValVT, MemVT) => LegalizeAction map, since it will be rarely used. Could probably even pack that key into a single uint64_t... or would a std::pair<unsigned, uint16_t> be preferable? Sparse avoids the need for Unspecified then...
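
A minimal sketch of the packed key (this mirrors the OverrideKey layout in the current revision of this patch: 8 bits per simple value type above a 32-bit address space; whether 8 bits stays sufficient for MVT is an assumption):

#include <cstdint>

static inline uint64_t packLoadExtKey(unsigned ValI, unsigned MemI,
                                      unsigned AddrSpace) {
  // (ValVT, MemVT, AddrSpace) -> single 64-bit map key
  return (uint64_t(ValI & 0xFF) << 40) | (uint64_t(MemI & 0xFF) << 32) |
         uint64_t(AddrSpace);
}

The std::pair alternative would trade this manual shifting for the pair's built-in lexicographic comparison.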

That should allow leaving the backends as is, making the feature opt-in.

github-actions bot commented Oct 8, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp,h -- llvm/include/llvm/CodeGen/BasicTTIImpl.h llvm/include/llvm/CodeGen/TargetLowering.h llvm/lib/CodeGen/CodeGenPrepare.cpp llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp llvm/lib/CodeGen/TargetLoweringBase.cpp llvm/lib/Target/AArch64/AArch64ISelLowering.cpp llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp llvm/lib/Target/AMDGPU/R600ISelLowering.cpp llvm/lib/Target/X86/X86ISelLowering.cpp

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index af087b154..1439683dc 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1250,7 +1250,8 @@ public:
           auto *LI = cast<LoadInst>(I->getOperand(0));
 
           if (DstLT.first == SrcLT.first &&
-              TLI->isLoadExtLegal(LType, ExtVT, LoadVT, LI->getPointerAddressSpace()))
+              TLI->isLoadExtLegal(LType, ExtVT, LoadVT,
+                                  LI->getPointerAddressSpace()))
             return 0;
         }
       }
@@ -1537,7 +1538,8 @@ public:
       if (Opcode == Instruction::Store)
         LA = getTLI()->getTruncStoreAction(LT.second, MemVT);
       else
-        LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT, AddressSpace);
+        LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT,
+                                        AddressSpace);
 
       if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) {
         // This is a vector load/store for some illegal type that is scalarized.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a5af81aad..46ce904fb 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1495,13 +1495,15 @@ public:
   }
 
   /// Return true if the specified load with extension is legal on this target.
-  bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT, unsigned AddrSpace) const {
+  bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT,
+                      unsigned AddrSpace) const {
     return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal;
   }
 
   /// Return true if the specified load with extension is legal or custom
   /// on this target.
-  bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT, unsigned AddrSpace) const {
+  bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT,
+                              unsigned AddrSpace) const {
     return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal ||
            getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Custom;
   }
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 1bcbc64f3..0b7ddf121 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -7347,7 +7347,8 @@ bool CodeGenPrepare::optimizeLoadExt(LoadInst *Load) {
 
   // Reject cases that won't be matched as extloads.
   if (!LoadResultVT.bitsGT(TruncVT) || !TruncVT.isRound() ||
-      !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT, Load->getPointerAddressSpace()))
+      !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT,
+                           Load->getPointerAddressSpace()))
     return false;
 
   IRBuilder<> Builder(Load->getNextNode());
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6ece42e72..7b77bdc4b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6889,7 +6889,8 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
 
   if (ExtVT == LoadedVT &&
       (!LegalOperations ||
-       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT, LoadN->getAddressSpace()))) {
+       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT,
+                          LoadN->getAddressSpace()))) {
     // ZEXTLOAD will match without needing to change the size of the value being
     // loaded.
     return true;
@@ -6904,8 +6905,8 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
     return false;
 
-  if (LegalOperations &&
-      !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT, LoadN->getAddressSpace()))
+  if (LegalOperations && !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT,
+                                             LoadN->getAddressSpace()))
     return false;
 
   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT, /*ByteOffset=*/0))
@@ -6967,8 +6968,8 @@ bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
     if (!SDValue(Load, 0).hasOneUse())
       return false;
 
-    if (LegalOperations &&
-        !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT, Load->getAddressSpace()))
+    if (LegalOperations && !TLI.isLoadExtLegal(ExtType, Load->getValueType(0),
+                                               MemVT, Load->getAddressSpace()))
       return false;
 
     // For the transform to be legal, the load must produce only two values
@@ -7480,7 +7481,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
       EVT LoadVT = MLoad->getMemoryVT();
       EVT ExtVT = VT;
-      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT, MLoad->getAddressSpace())) {
+      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT,
+                             MLoad->getAddressSpace())) {
         // For this AND to be a zero extension of the masked load the elements
         // of the BuildVec must mask the bottom bits of the extended element
         // type
@@ -7631,10 +7633,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
     // actually legal and isn't going to get expanded, else this is a false
     // optimisation.
-    bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
-                                                    Load->getValueType(0),
-                                                    Load->getMemoryVT(),
-                                                    Load->getAddressSpace());
+    bool CanZextLoadProfitably =
+        TLI.isLoadExtLegal(ISD::ZEXTLOAD, Load->getValueType(0),
+                           Load->getMemoryVT(), Load->getAddressSpace());
 
     // Resize the constant to the same size as the original memory access before
     // extension. If it is still the AllOnesValue then this AND is completely
@@ -7826,7 +7827,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
         ((!LegalOperations && LN0->isSimple()) ||
-         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, LN0->getAddressSpace()))) {
+         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT,
+                            LN0->getAddressSpace()))) {
       SDValue ExtLoad =
           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
@@ -14296,8 +14298,7 @@ static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
 
   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   EVT MemVT = LN0->getMemoryVT();
-  if ((LegalOperations || !LN0->isSimple() ||
-       VT.isVector()) &&
+  if ((LegalOperations || !LN0->isSimple() || VT.isVector()) &&
       !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT, LN0->getAddressSpace()))
     return SDValue();
 
@@ -14344,9 +14345,9 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
   // TODO: isFixedLengthVector() should be removed and any negative effects on
   // code generation being the result of that target's implementation of
   // isVectorLoadExtDesirable().
-  if ((LegalOperations || VT.isFixedLengthVector() ||
-       !LN0->isSimple()) &&
-      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(), LN0->getAddressSpace()))
+  if ((LegalOperations || VT.isFixedLengthVector() || !LN0->isSimple()) &&
+      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(),
+                          LN0->getAddressSpace()))
     return {};
 
   bool DoXform = true;
@@ -14388,7 +14389,8 @@ tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
     return SDValue();
 
   if ((LegalOperations || !Ld->isSimple()) &&
-      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0), Ld->getAddressSpace()))
+      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0),
+                                  Ld->getAddressSpace()))
     return SDValue();
 
   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
@@ -14731,7 +14733,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
     EVT MemVT = LN00->getMemoryVT();
     if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT, LN00->getAddressSpace()) &&
-      LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
+        LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
       SmallVector<SDNode*, 4> SetCCs;
       bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                              ISD::SIGN_EXTEND, SetCCs, TLI);
@@ -15316,7 +15318,8 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
     ISD::LoadExtType ExtType = LN0->getExtensionType();
     EVT MemVT = LN0->getMemoryVT();
-    if (!LegalOperations || TLI.isLoadExtLegal(ExtType, VT, MemVT, LN0->getAddressSpace())) {
+    if (!LegalOperations ||
+        TLI.isLoadExtLegal(ExtType, VT, MemVT, LN0->getAddressSpace())) {
       SDValue ExtLoad =
           DAG.getExtLoad(ExtType, DL, VT, LN0->getChain(), LN0->getBasePtr(),
                          MemVT, LN0->getMemOperand());
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 99dcf23e9..532a8c490 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -742,7 +742,7 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
       (SrcVT != MVT::i1 ||
        TLI.getLoadExtAction(ExtType, Node->getValueType(0), MVT::i1,
                             LD->getAddressSpace()) ==
-         TargetLowering::Promote)) {
+           TargetLowering::Promote)) {
     // Promote to a byte-sized load if not loading an integral number of
     // bytes.  For example, promote EXTLOAD:i20 -> EXTLOAD:i24.
     unsigned NewWidth = SrcVT.getStoreSizeInBits();
@@ -1849,7 +1849,8 @@ SDValue SelectionDAGLegalize::EmitStackConvert(SDValue SrcOp, EVT SlotVT,
   if ((SrcVT.bitsGT(SlotVT) &&
        !TLI.isTruncStoreLegalOrCustom(SrcOp.getValueType(), SlotVT)) ||
       (SlotVT.bitsLT(DestVT) &&
-       !TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT, DAG.getDataLayout().getAllocaAddrSpace())))
+       !TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, DestVT, SlotVT,
+                                   DAG.getDataLayout().getAllocaAddrSpace())))
     return SDValue();
 
   // Create the stack frame object.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index a25705235..be8e780a6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -301,7 +301,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
     ISD::LoadExtType ExtType = LD->getExtensionType();
     EVT LoadedVT = LD->getMemoryVT();
     if (LoadedVT.isVector() && ExtType != ISD::NON_EXTLOAD)
-      Action = TLI.getLoadExtAction(ExtType, LD->getValueType(0), LoadedVT, LD->getAddressSpace());
+      Action = TLI.getLoadExtAction(ExtType, LD->getValueType(0), LoadedVT,
+                                    LD->getAddressSpace());
     break;
   }
   case ISD::STORE: {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 3aa8e4602..92504762c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -234,9 +234,11 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
     setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand, AddrSpace);
     setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand, AddrSpace);
     setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16f16, Expand, AddrSpace);
-    setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16bf16, Expand, AddrSpace);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16bf16, Expand,
+                     AddrSpace);
     setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32f16, Expand, AddrSpace);
-    setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32bf16, Expand, AddrSpace);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v32f32, MVT::v32bf16, Expand,
+                     AddrSpace);
 
     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand, AddrSpace);
     setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand, AddrSpace);
@@ -256,7 +258,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
     setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand, AddrSpace);
     setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand, AddrSpace);
     setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f16, Expand, AddrSpace);
-    setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16bf16, Expand, AddrSpace);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16bf16, Expand,
+                     AddrSpace);
   }
 
   setOperationAction(ISD::STORE, MVT::f32, Promote);


dschuff commented Oct 8, 2025

So if I'm understanding you, then setLoadExtAction basically becomes overloaded rather than having a default argument, with the existing overload getting the existing behavior and the addrspace version overriding it. That also seems pretty reasonable, although it does make the API a bit more complex.
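
Concretely, the overload pair would look something like this (signatures are illustrative, sketched from this thread rather than taken from the final API):

// Existing form: applies to every address space, exactly as today.
void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT,
                      LegalizeAction Action);
// New form: records a per-address-space override that lookups consult first.
void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT,
                      LegalizeAction Action, unsigned AddrSpace);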

Let's also not completely throw out the solution of just changing the behavior without the default parameter. This has the advantage that it's the simplest API. There is the up-front cost of changing the backends (this can maybe be partly automated, but it also includes churn imposed on downstream backend maintainers, which needn't be an automatic dealbreaker but also isn't nothing). It also has the cost that future uses will have to include the argument even if they don't care about address spaces, but this doesn't seem so bad: since the behavior is otherwise the same, it shouldn't be hard to figure out.

Going with the default argument means less churn, but as you pointed out it's a silent change in behavior. (I started the CI for the PR, so we'll see which other targets might need updating in this version).

Churn is a one-time cost, but additional API complexity is an ongoing cost (and in that model you'd probably consider a silent behavior change a one-time cost that is spread out over time as people discover it, so probably worse than an equivalent all-at-once change). Of course you can argue all day about how bad extra API complexity is.

I wonder, is the additional complexity from the overload version buying us anything beyond avoiding one-time churn? Like, are there any targets that could use both versions together, or would they all just use one or the other?

QuantumSegfault (Author)

Besides avoiding touching the backends, the overrides idea makes it easier when a given backend DOES want to override ALL address spaces (like AMDGPU). It avoids having to loop over all (relevant) address spaces just to set the same behavior (see my changes to the AMDGPU backend... :/)

Like, are there any targets that could use both versions together, or would they all just use one or the other?

I'm not entirely sure. There don't seem to be many backends making heavy use of address spaces; mainly NVPTX, AMDGPU, and WebAssembly.

I did notice this when fixing up AMDGPU:

// EXTLOAD should be the same as ZEXTLOAD. It is legal for some address
// spaces, so it is custom lowered to handle those where it isn't.
for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD})
  for (MVT VT : MVT::integer_valuetypes()) {
    setLoadExtAction(Op, VT, MVT::i1, Promote);
    setLoadExtAction(Op, VT, MVT::i8, Custom);
    setLoadExtAction(Op, VT, MVT::i16, Custom);
  }

Main AMDGPU sets these i8 and i16 extending loads Legal, but then the R600 subtarget marks them Custom. With the new override mechanism, the specific address spaces needing Custom handling could be marked as such, while AMDGPU as a whole keeps the original API. Whether this would provide any tangible benefit is unclear, though. It may allow better optimization in cases that are currently being blocked by an unnecessary "Custom", which is the very reason we went down this rabbit hole to begin with.
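
For example, R600 could hypothetically narrow the marking to just the address spaces that need it (the choice of AMDGPUAS::PRIVATE_ADDRESS here is purely illustrative, not something this PR does):

for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD})
  for (MVT VT : MVT::integer_valuetypes()) {
    // Hypothetical: mark only one address space Custom; the rest keep the
    // Legal entries from the shared table.
    setLoadExtAction(Op, VT, MVT::i8, Custom, AMDGPUAS::PRIVATE_ADDRESS);
    setLoadExtAction(Op, VT, MVT::i16, Custom, AMDGPUAS::PRIVATE_ADDRESS);
  }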


Also, an override mechanism would prevent this from needing to be put here (this was moved from TargetLoweringBase::initActions):

if (!LoadExtActions.count(AddrSpace)) {
  LoadExtActions[AddrSpace]; // Initialize the map for the addrspace
  for (MVT AVT : MVT::all_valuetypes()) {
    for (MVT VT : {MVT::i2, MVT::i4, MVT::v128i2, MVT::v64i4}) {
      setLoadExtAction(ISD::EXTLOAD, AVT, VT, Expand, AddrSpace);
      setLoadExtAction(ISD::ZEXTLOAD, AVT, VT, Expand, AddrSpace);
    }
  }
}


I'm currently putting the finishing touches on the overrides concept, so I'll push that shortly and you can see which approach seems better.


QuantumSegfault commented Oct 8, 2025

Alright. Option 2 is done.

So between the two, which seems better?


As far as the CI/tests go, NVPTX seems to be the main failing point, as expected.

Though...
Analysis/CostModel/AArch64/masked_ldst.ll
CodeGen/PowerPC/pr48519.ll

Need to investigate

EDIT:
The PowerPC one is a load from an alternate addr space, so that just needs to be configured.

As for AArch64, I'm not quite sure. The test expects
CHECK-NEXT: Cost Model: Found costs of 0 for: %zext.nxv8i8to16 = zext <vscale x 8 x i8> %load.nxv8i8 to <vscale x 8 x i16>
but it reports "Found costs of 1" instead.


dschuff commented Oct 8, 2025

OK, let's get @arsenm's opinion then, from both an AMDGPU and a general codegen perspective.

QuantumSegfault marked this pull request as ready for review on October 9, 2025 at 07:03

llvmbot commented Oct 9, 2025

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: Demetrius Kanios (QuantumSegfault)

Changes

Reboot of #157627

Makes TargetLoweringBase::getLoadExtAction, TargetLoweringBase::setLoadExtAction, TargetLoweringBase::isLoadExtLegal, and TargetLoweringBase::isLoadExtLegalOrCustom dependent on the address space, rather than applying all-or-nothing logic across all address spaces.


Patch is 32.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162407.diff

10 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+12-4)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+48-19)
  • (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+70-46)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+11-7)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+2-1)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 42ddb32d24093..1439683dc5e96 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -28,6 +28,7 @@
 #include "llvm/Analysis/TargetTransformInfoImpl.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/TargetSubtargetInfo.h"
 #include "llvm/CodeGen/ValueTypes.h"
@@ -1244,9 +1245,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
         EVT LoadVT = EVT::getEVT(Src);
         unsigned LType =
           ((Opcode == Instruction::ZExt) ? ISD::ZEXTLOAD : ISD::SEXTLOAD);
-        if (DstLT.first == SrcLT.first &&
-            TLI->isLoadExtLegal(LType, ExtVT, LoadVT))
-          return 0;
+
+        if (I && isa<LoadInst>(I->getOperand(0))) {
+          auto *LI = cast<LoadInst>(I->getOperand(0));
+
+          if (DstLT.first == SrcLT.first &&
+              TLI->isLoadExtLegal(LType, ExtVT, LoadVT,
+                                  LI->getPointerAddressSpace()))
+            return 0;
+        }
       }
       break;
     case Instruction::AddrSpaceCast:
@@ -1531,7 +1538,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
       if (Opcode == Instruction::Store)
         LA = getTLI()->getTruncStoreAction(LT.second, MemVT);
       else
-        LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT);
+        LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT,
+                                        AddressSpace);
 
       if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) {
         // This is a vector load/store for some illegal type that is scalarized.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 88691b931a8d8..0b160443ade59 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1472,27 +1472,38 @@ class LLVM_ABI TargetLoweringBase {
   /// Return how this load with extension should be treated: either it is legal,
   /// needs to be promoted to a larger size, needs to be expanded to some other
   /// code sequence, or the target has a custom expander for it.
-  LegalizeAction getLoadExtAction(unsigned ExtType, EVT ValVT,
-                                  EVT MemVT) const {
-    if (ValVT.isExtended() || MemVT.isExtended()) return Expand;
-    unsigned ValI = (unsigned) ValVT.getSimpleVT().SimpleTy;
-    unsigned MemI = (unsigned) MemVT.getSimpleVT().SimpleTy;
+  LegalizeAction getLoadExtAction(unsigned ExtType, EVT ValVT, EVT MemVT,
+                                  unsigned AddrSpace) const {
+    if (ValVT.isExtended() || MemVT.isExtended())
+      return Expand;
+    unsigned ValI = (unsigned)ValVT.getSimpleVT().SimpleTy;
+    unsigned MemI = (unsigned)MemVT.getSimpleVT().SimpleTy;
     assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValI < MVT::VALUETYPE_SIZE &&
            MemI < MVT::VALUETYPE_SIZE && "Table isn't big enough!");
     unsigned Shift = 4 * ExtType;
+
+    uint64_t OverrideKey = ((uint64_t)(ValI & 0xFF) << 40) |
+                           ((uint64_t)(MemI & 0xFF) << 32) |
+                           (uint64_t)AddrSpace;
+
+    if (LoadExtActionOverrides.count(OverrideKey)) {
+      return (LegalizeAction)((LoadExtActionOverrides.at(OverrideKey) >> Shift) & 0xf);
+    }
     return (LegalizeAction)((LoadExtActions[ValI][MemI] >> Shift) & 0xf);
   }
 
   /// Return true if the specified load with extension is legal on this target.
-  bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT) const {
-    return getLoadExtAction(ExtType, ValVT, MemVT) == Legal;
+  bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT,
+                      unsigned AddrSpace) const {
+    return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal;
   }
 
   /// Return true if the specified load with extension is legal or custom
   /// on this target.
-  bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT) const {
-    return getLoadExtAction(ExtType, ValVT, MemVT) == Legal ||
-           getLoadExtAction(ExtType, ValVT, MemVT) == Custom;
+  bool isLoadExtLegalOrCustom(unsigned ExtType, EVT ValVT, EVT MemVT,
+                              unsigned AddrSpace) const {
+    return getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Legal ||
+           getLoadExtAction(ExtType, ValVT, MemVT, AddrSpace) == Custom;
   }
 
   /// Same as getLoadExtAction, but for atomic loads.
@@ -2634,23 +2645,38 @@ class LLVM_ABI TargetLoweringBase {
   /// Indicate that the specified load with extension does not work with the
   /// specified type and indicate what to do about it.
   void setLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT,
-                        LegalizeAction Action) {
+                        LegalizeAction Action, unsigned AddrSpace = ~0) {
     assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValVT.isValid() &&
            MemVT.isValid() && "Table isn't big enough!");
     assert((unsigned)Action < 0x10 && "too many bits for bitfield array");
+
     unsigned Shift = 4 * ExtType;
-    LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] &= ~((uint16_t)0xF << Shift);
-    LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] |= (uint16_t)Action << Shift;
+
+    if (AddrSpace == ~((unsigned)0)) {
+      LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] &=
+          ~((uint16_t)0xF << Shift);
+      LoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] |= (uint16_t)Action
+                                                        << Shift;
+    } else {
+      uint64_t OverrideKey = ((uint64_t)(ValVT.SimpleTy & 0xFF) << 40) |
+                             ((uint64_t)(MemVT.SimpleTy & 0xFF) << 32) |
+                             (uint64_t)AddrSpace;
+      uint16_t &OverrideVal = LoadExtActionOverrides[OverrideKey];
+
+      OverrideVal &= ~((uint16_t)0xF << Shift);
+      OverrideVal |= (uint16_t)Action << Shift;
+    }
   }
   void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT, MVT MemVT,
-                        LegalizeAction Action) {
+                        LegalizeAction Action, unsigned AddrSpace = ~0) {
     for (auto ExtType : ExtTypes)
-      setLoadExtAction(ExtType, ValVT, MemVT, Action);
+      setLoadExtAction(ExtType, ValVT, MemVT, Action, AddrSpace);
   }
   void setLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT,
-                        ArrayRef<MVT> MemVTs, LegalizeAction Action) {
+                        ArrayRef<MVT> MemVTs, LegalizeAction Action,
+                        unsigned AddrSpace = ~0) {
     for (auto MemVT : MemVTs)
-      setLoadExtAction(ExtTypes, ValVT, MemVT, Action);
+      setLoadExtAction(ExtTypes, ValVT, MemVT, Action, AddrSpace);
   }
 
   /// Let target indicate that an extending atomic load of the specified type
@@ -3126,7 +3152,7 @@ class LLVM_ABI TargetLoweringBase {
       LType = ISD::SEXTLOAD;
     }
 
-    return isLoadExtLegal(LType, VT, LoadVT);
+    return isLoadExtLegal(LType, VT, LoadVT, Load->getPointerAddressSpace());
   }
 
   /// Return true if any actual instruction that defines a value of type FromTy
@@ -3753,8 +3779,11 @@ class LLVM_ABI TargetLoweringBase {
   /// For each load extension type and each value type, keep a LegalizeAction
   /// that indicates how instruction selection should deal with a load of a
   /// specific value type and extension type. Uses 4-bits to store the action
-  /// for each of the 4 load ext types.
+  /// for each of the 4 load ext types. These actions can be specified for each
+  /// address space.
   uint16_t LoadExtActions[MVT::VALUETYPE_SIZE][MVT::VALUETYPE_SIZE];
+  using LoadExtActionOverrideMap = std::map<uint64_t, uint16_t>;
+  LoadExtActionOverrideMap LoadExtActionOverrides;
 
   /// Similar to LoadExtActions, but for atomic loads. Only Legal or Expand
   /// (default) values are supported.
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index eb73d01b3558c..0b7ddf1211f54 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -7347,7 +7347,8 @@ bool CodeGenPrepare::optimizeLoadExt(LoadInst *Load) {
 
   // Reject cases that won't be matched as extloads.
   if (!LoadResultVT.bitsGT(TruncVT) || !TruncVT.isRound() ||
-      !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT))
+      !TLI->isLoadExtLegal(ISD::ZEXTLOAD, LoadResultVT, TruncVT,
+                           Load->getPointerAddressSpace()))
     return false;
 
   IRBuilder<> Builder(Load->getNextNode());
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 309f1bea8b77c..8a8de4ba97f92 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6889,7 +6889,8 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
 
   if (ExtVT == LoadedVT &&
       (!LegalOperations ||
-       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
+       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT,
+                          LoadN->getAddressSpace()))) {
     // ZEXTLOAD will match without needing to change the size of the value being
     // loaded.
     return true;
@@ -6904,8 +6905,8 @@ bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
   if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
     return false;
 
-  if (LegalOperations &&
-      !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
+  if (LegalOperations && !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT,
+                                             LoadN->getAddressSpace()))
     return false;
 
   if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT, /*ByteOffset=*/0))
@@ -6967,8 +6968,8 @@ bool DAGCombiner::isLegalNarrowLdSt(LSBaseSDNode *LDST,
     if (!SDValue(Load, 0).hasOneUse())
       return false;
 
-    if (LegalOperations &&
-        !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT))
+    if (LegalOperations && !TLI.isLoadExtLegal(ExtType, Load->getValueType(0), MemVT,
+                                               Load->getAddressSpace()))
       return false;
 
     // For the transform to be legal, the load must produce only two values
@@ -7480,7 +7481,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
       EVT LoadVT = MLoad->getMemoryVT();
       EVT ExtVT = VT;
-      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
+      if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT,
+                             MLoad->getAddressSpace())) {
         // For this AND to be a zero extension of the masked load the elements
         // of the BuildVec must mask the bottom bits of the extended element
         // type
@@ -7631,9 +7633,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
     // actually legal and isn't going to get expanded, else this is a false
     // optimisation.
-    bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
-                                                    Load->getValueType(0),
-                                                    Load->getMemoryVT());
+    bool CanZextLoadProfitably =
+        TLI.isLoadExtLegal(ISD::ZEXTLOAD, Load->getValueType(0),
+                           Load->getMemoryVT(), Load->getAddressSpace());
 
     // Resize the constant to the same size as the original memory access before
     // extension. If it is still the AllOnesValue then this AND is completely
@@ -7825,7 +7827,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     APInt ExtBits = APInt::getHighBitsSet(ExtBitSize, ExtBitSize - MemBitSize);
     if (DAG.MaskedValueIsZero(N1, ExtBits) &&
         ((!LegalOperations && LN0->isSimple()) ||
-         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT))) {
+         TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, LN0->getAddressSpace()))) {
       SDValue ExtLoad =
           DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(N0), VT, LN0->getChain(),
                          LN0->getBasePtr(), MemVT, LN0->getMemOperand());
@@ -9747,10 +9749,13 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   // Before legalize we can introduce too wide illegal loads which will be later
   // split into legal sized loads. This enables us to combine i64 load by i8
   // patterns to a couple of i32 loads on 32 bit targets.
-  if (LegalOperations &&
-      !TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
-                          MemVT))
-    return SDValue();
+  if (LegalOperations) {
+    for (auto *L : Loads) {
+      if (!TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
+                              MemVT, L->getAddressSpace()))
+        return SDValue();
+    }
+  }
 
   // Check if the bytes of the OR we are looking at match with either big or
   // little endian value load
@@ -13425,9 +13430,11 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
       unsigned WideWidth = WideVT.getScalarSizeInBits();
       bool IsSigned = isSignedIntSetCC(CC);
       auto LoadExtOpcode = IsSigned ? ISD::SEXTLOAD : ISD::ZEXTLOAD;
-      if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() &&
-          SetCCWidth != 1 && SetCCWidth < WideWidth &&
-          TLI.isLoadExtLegalOrCustom(LoadExtOpcode, WideVT, NarrowVT) &&
+      if (LHS.getOpcode() == ISD::LOAD && LHS.hasOneUse() && SetCCWidth != 1 &&
+          SetCCWidth < WideWidth &&
+          TLI.isLoadExtLegalOrCustom(
+              LoadExtOpcode, WideVT, NarrowVT,
+              cast<LoadSDNode>(LHS)->getAddressSpace()) &&
           TLI.isOperationLegalOrCustom(ISD::SETCC, WideVT)) {
         // Both compare operands can be widened for free. The LHS can use an
         // extended load, and the RHS is a constant:
@@ -13874,8 +13881,10 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
   // Combine2), so we should conservatively check the OperationAction.
   LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
   LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
-  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
-      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()) ||
+  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT(),
+                          Load1->getAddressSpace()) ||
+      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT(),
+                          Load2->getAddressSpace()) ||
       (N0->getOpcode() == ISD::VSELECT && Level >= AfterLegalizeTypes &&
        TLI.getOperationAction(ISD::VSELECT, VT) != TargetLowering::Legal))
     return SDValue();
@@ -14099,13 +14108,15 @@ SDValue DAGCombiner::CombineExtLoad(SDNode *N) {
   // Try to split the vector types to get down to legal types.
   EVT SplitSrcVT = SrcVT;
   EVT SplitDstVT = DstVT;
-  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT) &&
+  while (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT,
+                                     LN0->getAddressSpace()) &&
          SplitSrcVT.getVectorNumElements() > 1) {
     SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
     SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
   }
 
-  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
+  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT,
+                                  LN0->getAddressSpace()))
     return SDValue();
 
   assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
@@ -14178,7 +14189,7 @@ SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
     return SDValue();
   LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
   EVT MemVT = Load->getMemoryVT();
-  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
+  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, Load->getAddressSpace()) ||
       Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
     return SDValue();
 
@@ -14286,9 +14297,8 @@ static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
 
   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   EVT MemVT = LN0->getMemoryVT();
-  if ((LegalOperations || !LN0->isSimple() ||
-       VT.isVector()) &&
-      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
+  if ((LegalOperations || !LN0->isSimple() || VT.isVector()) &&
+      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT, LN0->getAddressSpace()))
     return SDValue();
 
   SDValue ExtLoad =
@@ -14330,12 +14340,13 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
     }
   }
 
+  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   // TODO: isFixedLengthVector() should be removed and any negative effects on
   // code generation being the result of that target's implementation of
   // isVectorLoadExtDesirable().
-  if ((LegalOperations || VT.isFixedLengthVector() ||
-       !cast<LoadSDNode>(N0)->isSimple()) &&
-      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))
+  if ((LegalOperations || VT.isFixedLengthVector() || !LN0->isSimple()) &&
+      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(),
+                          LN0->getAddressSpace()))
     return {};
 
   bool DoXform = true;
@@ -14347,7 +14358,6 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
   if (!DoXform)
     return {};
 
-  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                                    LN0->getBasePtr(), N0.getValueType(),
                                    LN0->getMemOperand());
@@ -14377,8 +14387,9 @@ tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
   if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
     return SDValue();
 
-  if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) &&
-      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
+  if ((LegalOperations || !Ld->isSimple()) &&
+      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0),
+                                  Ld->getAddressSpace()))
     return SDValue();
 
   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
@@ -14522,7 +14533,8 @@ SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
         if (!(ISD::isNON_EXTLoad(V.getNode()) &&
               ISD::isUNINDEXEDLoad(V.getNode()) &&
               cast<LoadSDNode>(V)->isSimple() &&
-              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
+              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType(),
+                                 cast<LoadSDNode>(V)->getAddressSpace())))
           return false;
 
         // Non-chain users of this value must either be the setcc in this
@@ -14719,8 +14731,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
     EVT MemVT = LN00->getMemoryVT();
-    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
-      LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
+    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT, LN00->getAddressSpace()) &&
+        LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
       SmallVector<SDNode*, 4> SetCCs;
       bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                              ISD::SIGN_EXTEND, SetCCs, TLI);
@@ -15037,7 +15049,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
     EVT MemVT = LN00->getMemoryVT();
-    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
+    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, LN00->getAddressSpace()) &&
         LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
       bool DoXform = true;
       SmallVector<SDNode*, 4> SetCCs;
@@ -15268,7 +15280,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
       return foldedExt;
   } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
              ISD::isUNINDEXEDLoad(N0.getNode()) &&
-             TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, ...
[truncated]


llvmbot commented Oct 9, 2025

@llvm/pr-subscribers-backend-x86

+                                     LN0->getAddressSpace()) &&
          SplitSrcVT.getVectorNumElements() > 1) {
     SplitDstVT = DAG.GetSplitDestVTs(SplitDstVT).first;
     SplitSrcVT = DAG.GetSplitDestVTs(SplitSrcVT).first;
   }
 
-  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
+  if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT,
+                                  LN0->getAddressSpace()))
     return SDValue();
 
   assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
@@ -14178,7 +14189,7 @@ SDValue DAGCombiner::CombineZExtLogicopShiftLoad(SDNode *N) {
     return SDValue();
   LoadSDNode *Load = cast<LoadSDNode>(N1.getOperand(0));
   EVT MemVT = Load->getMemoryVT();
-  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) ||
+  if (!TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, Load->getAddressSpace()) ||
       Load->getExtensionType() == ISD::SEXTLOAD || Load->isIndexed())
     return SDValue();
 
@@ -14286,9 +14297,8 @@ static SDValue tryToFoldExtOfExtload(SelectionDAG &DAG, DAGCombiner &Combiner,
 
   LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   EVT MemVT = LN0->getMemoryVT();
-  if ((LegalOperations || !LN0->isSimple() ||
-       VT.isVector()) &&
-      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT))
+  if ((LegalOperations || !LN0->isSimple() || VT.isVector()) &&
+      !TLI.isLoadExtLegal(ExtLoadType, VT, MemVT, LN0->getAddressSpace()))
     return SDValue();
 
   SDValue ExtLoad =
@@ -14330,12 +14340,13 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
     }
   }
 
+  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   // TODO: isFixedLengthVector() should be removed and any negative effects on
   // code generation being the result of that target's implementation of
   // isVectorLoadExtDesirable().
-  if ((LegalOperations || VT.isFixedLengthVector() ||
-       !cast<LoadSDNode>(N0)->isSimple()) &&
-      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType()))
+  if ((LegalOperations || VT.isFixedLengthVector() || !LN0->isSimple()) &&
+      !TLI.isLoadExtLegal(ExtLoadType, VT, N0.getValueType(),
+                          LN0->getAddressSpace()))
     return {};
 
   bool DoXform = true;
@@ -14347,7 +14358,6 @@ static SDValue tryToFoldExtOfLoad(SelectionDAG &DAG, DAGCombiner &Combiner,
   if (!DoXform)
     return {};
 
-  LoadSDNode *LN0 = cast<LoadSDNode>(N0);
   SDValue ExtLoad = DAG.getExtLoad(ExtLoadType, SDLoc(LN0), VT, LN0->getChain(),
                                    LN0->getBasePtr(), N0.getValueType(),
                                    LN0->getMemOperand());
@@ -14377,8 +14387,9 @@ tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
   if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD)
     return SDValue();
 
-  if ((LegalOperations || !cast<MaskedLoadSDNode>(N0)->isSimple()) &&
-      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0)))
+  if ((LegalOperations || !Ld->isSimple()) &&
+      !TLI.isLoadExtLegalOrCustom(ExtLoadType, VT, Ld->getValueType(0),
+                                  Ld->getAddressSpace()))
     return SDValue();
 
   if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0)))
@@ -14522,7 +14533,8 @@ SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
         if (!(ISD::isNON_EXTLoad(V.getNode()) &&
               ISD::isUNINDEXEDLoad(V.getNode()) &&
               cast<LoadSDNode>(V)->isSimple() &&
-              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType())))
+              TLI.isLoadExtLegal(LoadOpcode, VT, V.getValueType(),
+                                 cast<LoadSDNode>(V)->getAddressSpace())))
           return false;
 
         // Non-chain users of this value must either be the setcc in this
@@ -14719,8 +14731,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
     EVT MemVT = LN00->getMemoryVT();
-    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT) &&
-      LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
+    if (TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, MemVT, LN00->getAddressSpace()) &&
+        LN00->getExtensionType() != ISD::ZEXTLOAD && LN00->isUnindexed()) {
       SmallVector<SDNode*, 4> SetCCs;
       bool DoXform = ExtendUsesToFormExtLoad(VT, N0.getNode(), N0.getOperand(0),
                                              ISD::SIGN_EXTEND, SetCCs, TLI);
@@ -15037,7 +15049,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
       (!LegalOperations && TLI.isOperationLegal(N0.getOpcode(), VT))) {
     LoadSDNode *LN00 = cast<LoadSDNode>(N0.getOperand(0));
     EVT MemVT = LN00->getMemoryVT();
-    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT) &&
+    if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT, LN00->getAddressSpace()) &&
         LN00->getExtensionType() != ISD::SEXTLOAD && LN00->isUnindexed()) {
       bool DoXform = true;
       SmallVector<SDNode*, 4> SetCCs;
@@ -15268,7 +15280,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
       return foldedExt;
   } else if (ISD::isNON_EXTLoad(N0.getNode()) &&
              ISD::isUNINDEXEDLoad(N0.getNode()) &&
-             TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, ...
[truncated]
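
A note on the shape of these changes: every legality query in the hunks above gains a trailing address-space argument taken from the memory access being combined, Load->getPointerAddressSpace() at the IR level in CodeGenPrepare and MemSDNode::getAddressSpace() at the DAG level. Below is a minimal sketch of the resulting query shape, assuming the four-argument signature shown in the diff; the helper function and its name are hypothetical and not part of the patch.

#include "llvm/CodeGen/SelectionDAGNodes.h" // LoadSDNode, ISD::ZEXTLOAD
#include "llvm/CodeGen/TargetLowering.h"    // TargetLowering

using namespace llvm;

// Hypothetical helper mirroring the pattern threaded through the DAGCombiner
// hunks: ask whether a zero-extending load is legal for the address space this
// particular load touches, instead of an all-or-nothing answer covering every
// address space.
static bool isZextLoadLegalFor(const TargetLowering &TLI, const LoadSDNode *LN,
                               EVT ResultVT) {
  return TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, LN->getMemoryVT(),
                            LN->getAddressSpace());
}

The one place where the rewrite is more than argument threading is MatchLoadCombine: because the combined value can be assembled from several loads, the single LegalOperations-guarded check becomes a loop that queries each load's address space and bails out if any of them fails.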
