2020#include " llvm/IR/DerivedTypes.h"
2121#include " llvm/IR/IRBuilder.h"
2222#include " llvm/IR/InstVisitor.h"
23+ #include " llvm/IR/ReplaceConstant.h"
2324#include " llvm/Support/Casting.h"
2425#include " llvm/Transforms/Utils/Local.h"
2526#include < cassert>
@@ -69,8 +70,8 @@ class DXILFlattenArraysVisitor
6970 bool visitExtractElementInst (ExtractElementInst &EEI) { return false ; }
7071 bool visitShuffleVectorInst (ShuffleVectorInst &SVI) { return false ; }
7172 bool visitPHINode (PHINode &PHI) { return false ; }
72- bool visitLoadInst (LoadInst &LI) { return false ; }
73- bool visitStoreInst (StoreInst &SI) { return false ; }
73+ bool visitLoadInst (LoadInst &LI);
74+ bool visitStoreInst (StoreInst &SI);
7475 bool visitCallInst (CallInst &ICI) { return false ; }
7576 bool visitFreezeInst (FreezeInst &FI) { return false ; }
7677 static bool isMultiDimensionalArray (Type *T);
@@ -94,7 +95,6 @@ class DXILFlattenArraysVisitor
9495 SmallVector<Value *> Indices = SmallVector<Value *>(),
9596 SmallVector<uint64_t > Dims = SmallVector<uint64_t >(),
9697 bool AllIndicesAreConstInt = true );
97- ConstantInt *computeFlatIndex (GetElementPtrInst &GEP);
9898 bool visitGetElementPtrInstInGEPChain (GetElementPtrInst &GEP);
9999 bool visitGetElementPtrInstInGEPChainBase (GEPData &GEPInfo,
100100 GetElementPtrInst &GEP);
@@ -164,6 +164,38 @@ Value *DXILFlattenArraysVisitor::instructionFlattenIndices(
164164 return FlatIndex;
165165}
166166
167+ bool DXILFlattenArraysVisitor::visitLoadInst (LoadInst &LI) {
168+ unsigned NumOperands = LI.getNumOperands ();
169+ for (unsigned I = 0 ; I < NumOperands; ++I) {
170+ Value *CurrOpperand = LI.getOperand (I);
171+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
172+ if (CE && CE->getOpcode () == Instruction::GetElementPtr) {
173+ convertUsersOfConstantsToInstructions (CE,
174+ /* RestrictToFunc=*/ nullptr ,
175+ /* RemoveDeadConstants=*/ false ,
176+ /* IncludeSelf=*/ true );
177+ return false ;
178+ }
179+ }
180+ return false ;
181+ }
182+
183+ bool DXILFlattenArraysVisitor::visitStoreInst (StoreInst &SI) {
184+ unsigned NumOperands = SI.getNumOperands ();
185+ for (unsigned I = 0 ; I < NumOperands; ++I) {
186+ Value *CurrOpperand = SI.getOperand (I);
187+ ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
188+ if (CE && CE->getOpcode () == Instruction::GetElementPtr) {
189+ convertUsersOfConstantsToInstructions (CE,
190+ /* RestrictToFunc=*/ nullptr ,
191+ /* RemoveDeadConstants=*/ false ,
192+ /* IncludeSelf=*/ true );
193+ return false ;
194+ }
195+ }
196+ return false ;
197+ }
198+
167199bool DXILFlattenArraysVisitor::visitAllocaInst (AllocaInst &AI) {
168200 if (!isMultiDimensionalArray (AI.getAllocatedType ()))
169201 return false ;
@@ -182,41 +214,6 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
182214 return true ;
183215}
184216
185- ConstantInt *
186- DXILFlattenArraysVisitor::computeFlatIndex (GetElementPtrInst &GEP) {
187- unsigned IndexAmount = GEP.getNumIndices ();
188- assert (IndexAmount >= 1 && " Need At least one Index" );
189- if (IndexAmount == 1 )
190- return dyn_cast<ConstantInt>(GEP.getOperand (GEP.getNumOperands () - 1 ));
191-
192- // Get the type of the base pointer.
193- Type *BaseType = GEP.getSourceElementType ();
194-
195- // Determine the dimensions of the multi-dimensional array.
196- SmallVector<int64_t > Dimensions;
197- while (auto *ArrType = dyn_cast<ArrayType>(BaseType)) {
198- Dimensions.push_back (ArrType->getNumElements ());
199- BaseType = ArrType->getElementType ();
200- }
201- unsigned FlatIndex = 0 ;
202- unsigned Multiplier = 1 ;
203- unsigned BitWidth = 32 ;
204- for (const Use &Index : GEP.indices ()) {
205- ConstantInt *CurrentIndex = dyn_cast<ConstantInt>(Index);
206- BitWidth = CurrentIndex->getBitWidth ();
207- if (!CurrentIndex)
208- return nullptr ;
209- int64_t IndexValue = CurrentIndex->getSExtValue ();
210- FlatIndex += IndexValue * Multiplier;
211-
212- if (!Dimensions.empty ()) {
213- Multiplier *= Dimensions.back (); // Use the last dimension size
214- Dimensions.pop_back (); // Remove the last dimension
215- }
216- }
217- return ConstantInt::get (GEP.getContext (), APInt (BitWidth, FlatIndex));
218- }
219-
220217void DXILFlattenArraysVisitor::recursivelyCollectGEPs (
221218 GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
222219 Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
@@ -240,12 +237,13 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
240237 for (auto *User : CurrGEP.users ()) {
241238 if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
242239 recursivelyCollectGEPs (*NestedGEP, FlattenedArrayType, PtrOperand,
243- ++GEPChainUseCount, Indices, Dims, AllIndicesAreConstInt);
240+ ++GEPChainUseCount, Indices, Dims,
241+ AllIndicesAreConstInt);
244242 GepUses = true ;
245243 }
246244 }
247245 // This case is just incase the gep chain doesn't end with a 1d array.
248- if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
246+ if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
249247 GEPChainMap.insert (
250248 {&CurrGEP,
251249 {std::move (FlattenedArrayType), PtrOperand, std::move (Indices),
@@ -295,10 +293,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
295293
296294 unsigned GEPChainUseCount = 0 ;
297295 recursivelyCollectGEPs (GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
298-
296+
299297 // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
300298 // Here recursion is used to get the length of the GEP chain.
301- // Handle zero uses here because there won't be an update via
299+ // Handle zero uses here because there won't be an update via
302300 // a child in the chain later.
303301 if (GEPChainUseCount == 0 ) {
304302 SmallVector<Value *> Indices ({GEP.getOperand (GEP.getNumOperands () - 1 )});
@@ -308,7 +306,7 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
308306 std::move (Indices), std::move (Dims), AllIndicesAreConstInt};
309307 return visitGetElementPtrInstInGEPChainBase (GEPInfo, GEP);
310308 }
311-
309+
312310 PotentiallyDeadInstrs.emplace_back (&GEP);
313311 return false ;
314312}
@@ -426,7 +424,7 @@ static bool flattenArrays(Module &M) {
426424 for (auto &[Old, New] : GlobalMap) {
427425 Old->replaceAllUsesWith (New);
428426 Old->eraseFromParent ();
429- MadeChange | = true ;
427+ MadeChange = true ;
430428 }
431429 return MadeChange;
432430}
0 commit comments