Skip to content

Conversation

@farzonl
Copy link
Member

@farzonl farzonl commented Jun 19, 2025

fixes #144608

  • there is a getPointerOperandIndex function so we don't need to iterate the operands trying to find the pointer. This resulted in a small cleanup to visitStoreInst and visitLoadInst.

  • The meat of this change was in visitGetElementPtrInst to account for allocas and not bail when we don't find a global.

fixes llvm#144608
- there is a getPointerOperandIndex function so we don't need to iterate
  the operands trying to find the pointer. This resulted in a small
  cleanup to visitStoreInst and visitLoadInst.

- The meat of this change was in visitGetElementPtrInst to account for
  allocas and not bail when we don't find a global.
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

Changes

fixes #144608

  • there is a getPointerOperandIndex function so we don't need to iterate the operands trying to find the pointer. This resulted in a small cleanup to visitStoreInst and visitLoadInst.

  • The meat of this change was in visitGetElementPtrInst to account for allocas and not bail when we don't find a global.


Full diff: https://github.com/llvm/llvm-project/pull/144959.diff

2 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXILDataScalarization.cpp (+55-49)
  • (modified) llvm/test/CodeGen/DirectX/scalarize-alloca.ll (+17-2)
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index 06708cec00cec..61c5301ed5051 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -14,11 +14,13 @@
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/ReplaceConstant.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 
@@ -127,71 +129,75 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
 }
 
 bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
-  unsigned NumOperands = LI.getNumOperands();
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = LI.getOperand(I);
-    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
-    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
-      GetElementPtrInst *OldGEP =
-          cast<GetElementPtrInst>(CE->getAsInstruction());
-      OldGEP->insertBefore(LI.getIterator());
-      IRBuilder<> Builder(&LI);
-      LoadInst *NewLoad =
-          Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
-      NewLoad->setAlignment(LI.getAlign());
-      LI.replaceAllUsesWith(NewLoad);
-      LI.eraseFromParent();
-      visitGetElementPtrInst(*OldGEP);
-      return true;
-    }
-    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
-      LI.setOperand(I, NewGlobal);
+  Value *PtrOperand = LI.getPointerOperand();
+  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
+  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+    GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
+    OldGEP->insertBefore(LI.getIterator());
+    IRBuilder<> Builder(&LI);
+    LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
+    NewLoad->setAlignment(LI.getAlign());
+    LI.replaceAllUsesWith(NewLoad);
+    LI.eraseFromParent();
+    visitGetElementPtrInst(*OldGEP);
+    return true;
   }
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
+    LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);
   return false;
 }
 
 bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
-  unsigned NumOperands = SI.getNumOperands();
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = SI.getOperand(I);
-    ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
-    if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
-      GetElementPtrInst *OldGEP =
-          cast<GetElementPtrInst>(CE->getAsInstruction());
-      OldGEP->insertBefore(SI.getIterator());
-      IRBuilder<> Builder(&SI);
-      StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
-      NewStore->setAlignment(SI.getAlign());
-      SI.replaceAllUsesWith(NewStore);
-      SI.eraseFromParent();
-      visitGetElementPtrInst(*OldGEP);
-      return true;
-    }
-    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
-      SI.setOperand(I, NewGlobal);
+
+  Value *PtrOperand = SI.getPointerOperand();
+  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
+  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
+    GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
+    OldGEP->insertBefore(SI.getIterator());
+    IRBuilder<> Builder(&SI);
+    StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
+    NewStore->setAlignment(SI.getAlign());
+    SI.replaceAllUsesWith(NewStore);
+    SI.eraseFromParent();
+    visitGetElementPtrInst(*OldGEP);
+    return true;
   }
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
+    SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);
+
   return false;
 }
 
 bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
-
-  unsigned NumOperands = GEPI.getNumOperands();
-  GlobalVariable *NewGlobal = nullptr;
-  for (unsigned I = 0; I < NumOperands; ++I) {
-    Value *CurrOpperand = GEPI.getOperand(I);
-    NewGlobal = lookupReplacementGlobal(CurrOpperand);
-    if (NewGlobal)
-      break;
+  Value *PtrOperand = GEPI.getPointerOperand();
+  Type *OrigGEPType = GEPI.getPointerOperandType();
+  Type *NewGEPType = OrigGEPType;
+  bool NeedsTransform = false;
+
+  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
+    NewGEPType = NewGlobal->getValueType();
+    PtrOperand = NewGlobal;
+    NeedsTransform = true;
+  } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
+    Type *AllocatedType = Alloca->getAllocatedType();
+    // OrigGEPType might just be a pointer lets make sure
+    // to add the allocated type so we have a size
+    if (AllocatedType != OrigGEPType) {
+      NewGEPType = AllocatedType;
+      NeedsTransform = true;
+    }
   }
-  if (!NewGlobal)
+
+  // Note: We bail if this isn't a gep touched via alloca or global
+  // transformations
+  if (!NeedsTransform)
     return false;
 
   IRBuilder<> Builder(&GEPI);
   SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());
 
-  Value *NewGEP =
-      Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
-                        GEPI.getName(), GEPI.getNoWrapFlags());
+  Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
+                                    GEPI.getName(), GEPI.getNoWrapFlags());
   GEPI.replaceAllUsesWith(NewGEP);
   GEPI.eraseFromParent();
   return true;
diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
index 4829f3a31791f..b589136d6965c 100644
--- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
+++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll
@@ -1,10 +1,25 @@
-; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK
-; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK
+; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
+; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
 
 ; CHECK-LABEL: alloca_2d__vec_test
 define void @alloca_2d__vec_test() local_unnamed_addr #2 {
   ; SCHECK:  alloca [2 x [4 x i32]], align 16
   ; FCHECK:  alloca [8 x i32], align 16
+  ; CHECK: ret void
   %1 = alloca [2 x <4 x i32>], align 16
   ret void
 }
+
+; CHECK-LABEL: alloca_2d_gep_test
+define void @alloca_2d_gep_test() {
+  ; SCHECK:  [[alloca_val:%.*]] = alloca [2 x [2 x i32]], align 16
+  ; FCHECK:  [[alloca_val:%.*]] = alloca [4 x i32], align 16
+  ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
+  ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [2 x [2 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[tid]]
+  ; CHECK: ret void
+  %1 = alloca [2 x <2 x i32>], align 16
+  %2 = tail call i32 @llvm.dx.thread.id(i32 0)
+  %3 = getelementptr inbounds nuw [2 x <2 x i32>], ptr %1, i32 0, i32 %2
+  ret void
+}

@farzonl farzonl merged commit 2a4207e into llvm:main Jun 20, 2025
10 checks passed
@damyanp damyanp removed this from HLSL Support Jun 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[DirectX] DXIL Data Scalarization is not scalarizing GEP for an array of vectors in function parameter

4 participants