2020#include " llvm/IR/InstVisitor.h"
2121#include " llvm/IR/ReplaceConstant.h"
2222#include " llvm/Support/Casting.h"
23+ #include " llvm/Support/MathExtras.h"
2324#include " llvm/Transforms/Utils/Local.h"
2425#include < cassert>
2526#include < cstddef>
@@ -40,18 +41,19 @@ class DXILFlattenArraysLegacy : public ModulePass {
4041 static char ID; // Pass identification.
4142};
4243
43- struct GEPData {
44- ArrayType *ParentArrayType;
45- Value *ParentOperand;
46- SmallVector<Value *> Indices;
47- SmallVector<uint64_t > Dims;
48- bool AllIndicesAreConstInt;
44+ struct GEPInfo {
45+ ArrayType *RootFlattenedArrayType;
46+ Value *RootPointerOperand;
47+ SmallMapVector<Value *, APInt, 4 > VariableOffsets;
48+ APInt ConstantOffset;
4949};
5050
5151class DXILFlattenArraysVisitor
5252 : public InstVisitor<DXILFlattenArraysVisitor, bool > {
5353public:
54- DXILFlattenArraysVisitor () {}
54+ DXILFlattenArraysVisitor (
55+ SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
56+ : GlobalMap(GlobalMap) {}
5557 bool visit (Function &F);
5658 // InstVisitor methods. They return true if the instruction was scalarized,
5759 // false if nothing changed.
@@ -78,35 +80,20 @@ class DXILFlattenArraysVisitor
7880
7981private:
8082 SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
81- DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
83+ SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
84+ SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
8285 bool finish ();
8386 ConstantInt *genConstFlattenIndices (ArrayRef<Value *> Indices,
8487 ArrayRef<uint64_t > Dims,
8588 IRBuilder<> &Builder);
8689 Value *genInstructionFlattenIndices (ArrayRef<Value *> Indices,
8790 ArrayRef<uint64_t > Dims,
8891 IRBuilder<> &Builder);
89-
90- // Helper function to collect indices and dimensions from a GEP instruction
91- void collectIndicesAndDimsFromGEP (GetElementPtrInst &GEP,
92- SmallVectorImpl<Value *> &Indices,
93- SmallVectorImpl<uint64_t > &Dims,
94- bool &AllIndicesAreConstInt);
95-
96- void
97- recursivelyCollectGEPs (GetElementPtrInst &CurrGEP,
98- ArrayType *FlattenedArrayType, Value *PtrOperand,
99- unsigned &GEPChainUseCount,
100- SmallVector<Value *> Indices = SmallVector<Value *>(),
101- SmallVector<uint64_t > Dims = SmallVector<uint64_t >(),
102- bool AllIndicesAreConstInt = true );
103- bool visitGetElementPtrInstInGEPChain (GetElementPtrInst &GEP);
104- bool visitGetElementPtrInstInGEPChainBase (GEPData &GEPInfo,
105- GetElementPtrInst &GEP);
10692};
10793} // namespace
10894
10995bool DXILFlattenArraysVisitor::finish () {
96+ GEPChainInfoMap.clear ();
11097 RecursivelyDeleteTriviallyDeadInstructionsPermissive (PotentiallyDeadInstrs);
11198 return true ;
11299}
@@ -225,131 +212,149 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
225212 return true ;
226213}
227214
228- void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP (
229- GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
230- SmallVectorImpl<uint64_t > &Dims, bool &AllIndicesAreConstInt) {
231-
232- Type *CurrentType = GEP.getSourceElementType ();
233-
234- // Note index 0 is the ptr index.
235- for (Value *Index : llvm::drop_begin (GEP.indices (), 1 )) {
236- Indices.push_back (Index);
237- AllIndicesAreConstInt &= isa<ConstantInt>(Index);
215+ bool DXILFlattenArraysVisitor::visitGetElementPtrInst (GetElementPtrInst &GEP) {
216+ // Do not visit GEPs more than once
217+ if (GEPChainInfoMap.contains (cast<GEPOperator>(&GEP)))
218+ return false ;
238219
239- if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
240- Dims.push_back (ArrayTy->getNumElements ());
241- CurrentType = ArrayTy->getElementType ();
242- } else {
243- assert (false && " Expected array type in GEP chain" );
244- }
220+ Value *PtrOperand = GEP.getPointerOperand ();
221+ // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI
222+ // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets
223+ // pointer types, then the handling of this case can be implemented later.
224+ assert (!isa<PHINode>(PtrOperand) &&
225+ " Pointer operand of GEP should not be a PHI Node" );
226+
227+ // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
228+ // it can be visited
229+ if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
230+ PtrOpGEPCE && PtrOpGEPCE->getOpcode () == Instruction::GetElementPtr) {
231+ GetElementPtrInst *OldGEPI =
232+ cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction ());
233+ OldGEPI->insertBefore (GEP.getIterator ());
234+
235+ IRBuilder<> Builder (&GEP);
236+ SmallVector<Value *> Indices (GEP.indices ());
237+ Value *NewGEP =
238+ Builder.CreateGEP (GEP.getSourceElementType (), OldGEPI, Indices,
239+ GEP.getName (), GEP.getNoWrapFlags ());
240+ assert (isa<GetElementPtrInst>(NewGEP) &&
241+ " Expected newly-created GEP to be an instruction" );
242+ GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
243+
244+ GEP.replaceAllUsesWith (NewGEPI);
245+ GEP.eraseFromParent ();
246+ visitGetElementPtrInst (*OldGEPI);
247+ visitGetElementPtrInst (*NewGEPI);
248+ return true ;
245249 }
246- }
247-
248- void DXILFlattenArraysVisitor::recursivelyCollectGEPs (
249- GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
250- Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
251- SmallVector<uint64_t > Dims, bool AllIndicesAreConstInt) {
252- // Check if this GEP is already in the map to avoid circular references
253- if (GEPChainMap.count (&CurrGEP) > 0 )
254- return ;
255250
256- // Collect indices and dimensions from the current GEP
257- collectIndicesAndDimsFromGEP (CurrGEP, Indices, Dims, AllIndicesAreConstInt);
258- bool IsMultiDimArr = isMultiDimensionalArray (CurrGEP.getSourceElementType ());
259- if (!IsMultiDimArr) {
260- assert (GEPChainUseCount < FlattenedArrayType->getNumElements ());
261- GEPChainMap.insert (
262- {&CurrGEP,
263- {std::move (FlattenedArrayType), PtrOperand, std::move (Indices),
264- std::move (Dims), AllIndicesAreConstInt}});
265- return ;
266- }
267- bool GepUses = false ;
268- for (auto *User : CurrGEP.users ()) {
269- if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
270- recursivelyCollectGEPs (*NestedGEP, FlattenedArrayType, PtrOperand,
271- ++GEPChainUseCount, Indices, Dims,
272- AllIndicesAreConstInt);
273- GepUses = true ;
274- }
275- }
276- // This case is just incase the gep chain doesn't end with a 1d array.
277- if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
278- GEPChainMap.insert (
279- {&CurrGEP,
280- {std::move (FlattenedArrayType), PtrOperand, std::move (Indices),
281- std::move (Dims), AllIndicesAreConstInt}});
251+ // Construct GEPInfo for this GEP
252+ GEPInfo Info;
253+
254+ // Obtain the variable and constant byte offsets computed by this GEP
255+ const DataLayout &DL = GEP.getDataLayout ();
256+ unsigned BitWidth = DL.getIndexTypeSizeInBits (GEP.getType ());
257+ Info.ConstantOffset = {BitWidth, 0 };
258+ [[maybe_unused]] bool Success = GEP.collectOffset (
259+ DL, BitWidth, Info.VariableOffsets , Info.ConstantOffset );
260+ assert (Success && " Failed to collect offsets for GEP" );
261+
262+ // If there is a parent GEP, inherit the root array type and pointer, and
263+ // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
264+ // chain and we need to determine the root array type
265+ if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
266+ assert (GEPChainInfoMap.contains (PtrOpGEP) &&
267+ " Expected parent GEP to be visited before this GEP" );
268+ GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
269+ Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType ;
270+ Info.RootPointerOperand = PGEPInfo.RootPointerOperand ;
271+ for (auto &VariableOffset : PGEPInfo.VariableOffsets )
272+ Info.VariableOffsets .insert (VariableOffset);
273+ Info.ConstantOffset += PGEPInfo.ConstantOffset ;
274+ } else {
275+ Info.RootPointerOperand = PtrOperand;
276+
277+ // We should try to determine the type of the root from the pointer rather
278+ // than the GEP's source element type because this could be a scalar GEP
279+ // into an array-typed pointer from an Alloca or Global Variable.
280+ Type *RootTy = GEP.getSourceElementType ();
281+ if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
282+ if (GlobalMap.contains (GlobalVar))
283+ GlobalVar = GlobalMap[GlobalVar];
284+ Info.RootPointerOperand = GlobalVar;
285+ RootTy = GlobalVar->getValueType ();
286+ } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))
287+ RootTy = Alloca->getAllocatedType ();
288+ assert (!isMultiDimensionalArray (RootTy) &&
289+ " Expected root array type to be flattened" );
290+
291+ // If the root type is not an array, we don't need to do any flattening
292+ if (!isa<ArrayType>(RootTy))
293+ return false ;
294+
295+ Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
282296 }
283- }
284297
285- bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain (
286- GetElementPtrInst &GEP) {
287- GEPData GEPInfo = GEPChainMap.at (&GEP);
288- return visitGetElementPtrInstInGEPChainBase (GEPInfo, GEP);
289- }
290- bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase (
291- GEPData &GEPInfo, GetElementPtrInst &GEP) {
292- IRBuilder<> Builder (&GEP);
293- Value *FlatIndex;
294- if (GEPInfo.AllIndicesAreConstInt )
295- FlatIndex = genConstFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
296- else
297- FlatIndex =
298- genInstructionFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
299-
300- ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType ;
301-
302- // Don't append '.flat' to an empty string. If the SSA name isn't available
303- // it could conflict with the ParentOperand's name.
304- std::string FlatName = GEP.hasName () ? GEP.getName ().str () + " .flat" : " " ;
305-
306- Value *FlatGEP = Builder.CreateGEP (FlattenedArrayType, GEPInfo.ParentOperand ,
307- {Builder.getInt32 (0 ), FlatIndex}, FlatName,
308- GEP.getNoWrapFlags ());
309-
310- // Note: Old gep will become an invalid instruction after replaceAllUsesWith.
311- // Erase the old GEP in the map before to avoid invalid instructions
312- // and circular references.
313- GEPChainMap.erase (&GEP);
314-
315- GEP.replaceAllUsesWith (FlatGEP);
316- GEP.eraseFromParent ();
317- return true ;
318- }
319-
320- bool DXILFlattenArraysVisitor::visitGetElementPtrInst (GetElementPtrInst &GEP) {
321- auto It = GEPChainMap.find (&GEP);
322- if (It != GEPChainMap.end ())
323- return visitGetElementPtrInstInGEPChain (GEP);
324- if (!isMultiDimensionalArray (GEP.getSourceElementType ()))
325- return false ;
326-
327- ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType ());
328- IRBuilder<> Builder (&GEP);
329- auto [TotalElements, BaseType] = getElementCountAndType (ArrType);
330- ArrayType *FlattenedArrayType = ArrayType::get (BaseType, TotalElements);
331-
332- Value *PtrOperand = GEP.getPointerOperand ();
298+ // GEPs without users or GEPs with non-GEP users should be replaced such that
299+ // the chain of GEPs they are a part of is collapsed to a single GEP into a
300+ // flattened array.
301+ bool ReplaceThisGEP = GEP.users ().empty ();
302+ for (Value *User : GEP.users ())
303+ if (!isa<GetElementPtrInst>(User))
304+ ReplaceThisGEP = true ;
305+
306+ if (ReplaceThisGEP) {
307+ unsigned BytesPerElem =
308+ DL.getTypeAllocSize (Info.RootFlattenedArrayType ->getArrayElementType ());
309+ assert (isPowerOf2_32 (BytesPerElem) &&
310+ " Bytes per element should be a power of 2" );
311+
312+ // Compute the 32-bit index for this flattened GEP from the constant and
313+ // variable byte offsets in the GEPInfo
314+ IRBuilder<> Builder (&GEP);
315+ Value *ZeroIndex = Builder.getInt32 (0 );
316+ uint64_t ConstantOffset =
317+ Info.ConstantOffset .udiv (BytesPerElem).getZExtValue ();
318+ assert (ConstantOffset < UINT32_MAX &&
319+ " Constant byte offset for flat GEP index must fit within 32 bits" );
320+ Value *FlattenedIndex = Builder.getInt32 (ConstantOffset);
321+ for (auto [VarIndex, Multiplier] : Info.VariableOffsets ) {
322+ assert (Multiplier.getActiveBits () <= 32 &&
323+ " The multiplier for a flat GEP index must fit within 32 bits" );
324+ assert (VarIndex->getType ()->isIntegerTy (32 ) &&
325+ " Expected i32-typed GEP indices" );
326+ Value *VI;
327+ if (Multiplier.getZExtValue () % BytesPerElem != 0 ) {
328+ // This can happen, e.g., with i8 GEPs. To handle this we just divide
329+ // by BytesPerElem using an instruction after multiplying VarIndex by
330+ // Multiplier.
331+ VI = Builder.CreateMul (VarIndex,
332+ Builder.getInt32 (Multiplier.getZExtValue ()));
333+ VI = Builder.CreateLShr (VI, Builder.getInt32 (Log2_32 (BytesPerElem)));
334+ } else
335+ VI = Builder.CreateMul (
336+ VarIndex,
337+ Builder.getInt32 (Multiplier.getZExtValue () / BytesPerElem));
338+ FlattenedIndex = Builder.CreateAdd (FlattenedIndex, VI);
339+ }
333340
334- unsigned GEPChainUseCount = 0 ;
335- recursivelyCollectGEPs (GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
336-
337- // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
338- // Here recursion is used to get the length of the GEP chain.
339- // Handle zero uses here because there won't be an update via
340- // a child in the chain later.
341- if (GEPChainUseCount == 0 ) {
342- SmallVector<Value *> Indices;
343- SmallVector<uint64_t > Dims;
344- bool AllIndicesAreConstInt = true ;
345-
346- // Collect indices and dimensions from the GEP
347- collectIndicesAndDimsFromGEP (GEP, Indices, Dims, AllIndicesAreConstInt);
348- GEPData GEPInfo{std::move (FlattenedArrayType), PtrOperand,
349- std::move (Indices), std::move (Dims), AllIndicesAreConstInt};
350- return visitGetElementPtrInstInGEPChainBase (GEPInfo, GEP);
341+ // Construct a new GEP for the flattened array to replace the current GEP
342+ Value *NewGEP = Builder.CreateGEP (
343+ Info.RootFlattenedArrayType , Info.RootPointerOperand ,
344+ {ZeroIndex, FlattenedIndex}, GEP.getName (), GEP.getNoWrapFlags ());
345+
346+ // Replace the current GEP with the new GEP. Store GEPInfo into the map
347+ // for later use in case this GEP was not the end of the chain
348+ GEPChainInfoMap.insert ({cast<GEPOperator>(NewGEP), std::move (Info)});
349+ GEP.replaceAllUsesWith (NewGEP);
350+ GEP.eraseFromParent ();
351+ return true ;
351352 }
352353
354+ // This GEP is potentially dead at the end of the pass since it may not have
355+ // any users anymore after GEP chains have been collapsed. We still store its
356+ // GEPInfo so GEPs down the chain can use it to compute their indices.
357+ GEPChainInfoMap.insert ({cast<GEPOperator>(&GEP), std::move (Info)});
353358 PotentiallyDeadInstrs.emplace_back (&GEP);
354359 return false ;
355360}
@@ -416,9 +421,8 @@ static Constant *transformInitializer(Constant *Init, Type *OrigType,
416421 return ConstantArray::get (FlattenedType, FlattenedElements);
417422}
418423
419- static void
420- flattenGlobalArrays (Module &M,
421- DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
424+ static void flattenGlobalArrays (
425+ Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
422426 LLVMContext &Ctx = M.getContext ();
423427 for (GlobalVariable &G : M.globals ()) {
424428 Type *OrigType = G.getValueType ();
@@ -456,9 +460,9 @@ flattenGlobalArrays(Module &M,
456460
457461static bool flattenArrays (Module &M) {
458462 bool MadeChange = false ;
459- DXILFlattenArraysVisitor Impl;
460- DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
463+ SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
461464 flattenGlobalArrays (M, GlobalMap);
465+ DXILFlattenArraysVisitor Impl (GlobalMap);
462466 for (auto &F : make_early_inc_range (M.functions ())) {
463467 if (F.isDeclaration ())
464468 continue ;
0 commit comments