20
20
#include " llvm/IR/InstVisitor.h"
21
21
#include " llvm/IR/ReplaceConstant.h"
22
22
#include " llvm/Support/Casting.h"
23
+ #include " llvm/Support/MathExtras.h"
23
24
#include " llvm/Transforms/Utils/Local.h"
24
25
#include < cassert>
25
26
#include < cstddef>
@@ -40,18 +41,19 @@ class DXILFlattenArraysLegacy : public ModulePass {
40
41
static char ID; // Pass identification.
41
42
};
42
43
43
- struct GEPData {
44
- ArrayType *ParentArrayType;
45
- Value *ParentOperand;
46
- SmallVector<Value *> Indices;
47
- SmallVector<uint64_t > Dims;
48
- bool AllIndicesAreConstInt;
44
+ struct GEPInfo {
45
+ ArrayType *RootFlattenedArrayType;
46
+ Value *RootPointerOperand;
47
+ SmallMapVector<Value *, APInt, 4 > VariableOffsets;
48
+ APInt ConstantOffset;
49
49
};
50
50
51
51
class DXILFlattenArraysVisitor
52
52
: public InstVisitor<DXILFlattenArraysVisitor, bool > {
53
53
public:
54
- DXILFlattenArraysVisitor () {}
54
+ DXILFlattenArraysVisitor (
55
+ SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
56
+ : GlobalMap(GlobalMap) {}
55
57
bool visit (Function &F);
56
58
// InstVisitor methods. They return true if the instruction was scalarized,
57
59
// false if nothing changed.
@@ -78,35 +80,20 @@ class DXILFlattenArraysVisitor
78
80
79
81
private:
80
82
SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
81
- DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
83
+ SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
84
+ SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
82
85
bool finish ();
83
86
ConstantInt *genConstFlattenIndices (ArrayRef<Value *> Indices,
84
87
ArrayRef<uint64_t > Dims,
85
88
IRBuilder<> &Builder);
86
89
Value *genInstructionFlattenIndices (ArrayRef<Value *> Indices,
87
90
ArrayRef<uint64_t > Dims,
88
91
IRBuilder<> &Builder);
89
-
90
- // Helper function to collect indices and dimensions from a GEP instruction
91
- void collectIndicesAndDimsFromGEP (GetElementPtrInst &GEP,
92
- SmallVectorImpl<Value *> &Indices,
93
- SmallVectorImpl<uint64_t > &Dims,
94
- bool &AllIndicesAreConstInt);
95
-
96
- void
97
- recursivelyCollectGEPs (GetElementPtrInst &CurrGEP,
98
- ArrayType *FlattenedArrayType, Value *PtrOperand,
99
- unsigned &GEPChainUseCount,
100
- SmallVector<Value *> Indices = SmallVector<Value *>(),
101
- SmallVector<uint64_t > Dims = SmallVector<uint64_t >(),
102
- bool AllIndicesAreConstInt = true );
103
- bool visitGetElementPtrInstInGEPChain (GetElementPtrInst &GEP);
104
- bool visitGetElementPtrInstInGEPChainBase (GEPData &GEPInfo,
105
- GetElementPtrInst &GEP);
106
92
};
107
93
} // namespace
108
94
109
95
bool DXILFlattenArraysVisitor::finish () {
96
+ GEPChainInfoMap.clear ();
110
97
RecursivelyDeleteTriviallyDeadInstructionsPermissive (PotentiallyDeadInstrs);
111
98
return true ;
112
99
}
@@ -225,131 +212,149 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
225
212
return true ;
226
213
}
227
214
228
- void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP (
229
- GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
230
- SmallVectorImpl<uint64_t > &Dims, bool &AllIndicesAreConstInt) {
231
-
232
- Type *CurrentType = GEP.getSourceElementType ();
233
-
234
- // Note index 0 is the ptr index.
235
- for (Value *Index : llvm::drop_begin (GEP.indices (), 1 )) {
236
- Indices.push_back (Index);
237
- AllIndicesAreConstInt &= isa<ConstantInt>(Index);
215
+ bool DXILFlattenArraysVisitor::visitGetElementPtrInst (GetElementPtrInst &GEP) {
216
+ // Do not visit GEPs more than once
217
+ if (GEPChainInfoMap.contains (cast<GEPOperator>(&GEP)))
218
+ return false ;
238
219
239
- if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
240
- Dims.push_back (ArrayTy->getNumElements ());
241
- CurrentType = ArrayTy->getElementType ();
242
- } else {
243
- assert (false && " Expected array type in GEP chain" );
244
- }
220
+ Value *PtrOperand = GEP.getPointerOperand ();
221
+ // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI
222
+ // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets
223
+ // pointer types, then the handling of this case can be implemented later.
224
+ assert (!isa<PHINode>(PtrOperand) &&
225
+ " Pointer operand of GEP should not be a PHI Node" );
226
+
227
+ // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
228
+ // it can be visited
229
+ if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
230
+ PtrOpGEPCE && PtrOpGEPCE->getOpcode () == Instruction::GetElementPtr) {
231
+ GetElementPtrInst *OldGEPI =
232
+ cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction ());
233
+ OldGEPI->insertBefore (GEP.getIterator ());
234
+
235
+ IRBuilder<> Builder (&GEP);
236
+ SmallVector<Value *> Indices (GEP.indices ());
237
+ Value *NewGEP =
238
+ Builder.CreateGEP (GEP.getSourceElementType (), OldGEPI, Indices,
239
+ GEP.getName (), GEP.getNoWrapFlags ());
240
+ assert (isa<GetElementPtrInst>(NewGEP) &&
241
+ " Expected newly-created GEP to be an instruction" );
242
+ GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
243
+
244
+ GEP.replaceAllUsesWith (NewGEPI);
245
+ GEP.eraseFromParent ();
246
+ visitGetElementPtrInst (*OldGEPI);
247
+ visitGetElementPtrInst (*NewGEPI);
248
+ return true ;
245
249
}
246
- }
247
-
248
- void DXILFlattenArraysVisitor::recursivelyCollectGEPs (
249
- GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
250
- Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
251
- SmallVector<uint64_t > Dims, bool AllIndicesAreConstInt) {
252
- // Check if this GEP is already in the map to avoid circular references
253
- if (GEPChainMap.count (&CurrGEP) > 0 )
254
- return ;
255
250
256
- // Collect indices and dimensions from the current GEP
257
- collectIndicesAndDimsFromGEP (CurrGEP, Indices, Dims, AllIndicesAreConstInt);
258
- bool IsMultiDimArr = isMultiDimensionalArray (CurrGEP.getSourceElementType ());
259
- if (!IsMultiDimArr) {
260
- assert (GEPChainUseCount < FlattenedArrayType->getNumElements ());
261
- GEPChainMap.insert (
262
- {&CurrGEP,
263
- {std::move (FlattenedArrayType), PtrOperand, std::move (Indices),
264
- std::move (Dims), AllIndicesAreConstInt}});
265
- return ;
266
- }
267
- bool GepUses = false ;
268
- for (auto *User : CurrGEP.users ()) {
269
- if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
270
- recursivelyCollectGEPs (*NestedGEP, FlattenedArrayType, PtrOperand,
271
- ++GEPChainUseCount, Indices, Dims,
272
- AllIndicesAreConstInt);
273
- GepUses = true ;
274
- }
275
- }
276
- // This case is just incase the gep chain doesn't end with a 1d array.
277
- if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
278
- GEPChainMap.insert (
279
- {&CurrGEP,
280
- {std::move (FlattenedArrayType), PtrOperand, std::move (Indices),
281
- std::move (Dims), AllIndicesAreConstInt}});
251
+ // Construct GEPInfo for this GEP
252
+ GEPInfo Info;
253
+
254
+ // Obtain the variable and constant byte offsets computed by this GEP
255
+ const DataLayout &DL = GEP.getDataLayout ();
256
+ unsigned BitWidth = DL.getIndexTypeSizeInBits (GEP.getType ());
257
+ Info.ConstantOffset = {BitWidth, 0 };
258
+ [[maybe_unused]] bool Success = GEP.collectOffset (
259
+ DL, BitWidth, Info.VariableOffsets , Info.ConstantOffset );
260
+ assert (Success && " Failed to collect offsets for GEP" );
261
+
262
+ // If there is a parent GEP, inherit the root array type and pointer, and
263
+ // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
264
+ // chain and we need to deterine the root array type
265
+ if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
266
+ assert (GEPChainInfoMap.contains (PtrOpGEP) &&
267
+ " Expected parent GEP to be visited before this GEP" );
268
+ GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
269
+ Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType ;
270
+ Info.RootPointerOperand = PGEPInfo.RootPointerOperand ;
271
+ for (auto &VariableOffset : PGEPInfo.VariableOffsets )
272
+ Info.VariableOffsets .insert (VariableOffset);
273
+ Info.ConstantOffset += PGEPInfo.ConstantOffset ;
274
+ } else {
275
+ Info.RootPointerOperand = PtrOperand;
276
+
277
+ // We should try to determine the type of the root from the pointer rather
278
+ // than the GEP's source element type because this could be a scalar GEP
279
+ // into an array-typed pointer from an Alloca or Global Variable.
280
+ Type *RootTy = GEP.getSourceElementType ();
281
+ if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
282
+ if (GlobalMap.contains (GlobalVar))
283
+ GlobalVar = GlobalMap[GlobalVar];
284
+ Info.RootPointerOperand = GlobalVar;
285
+ RootTy = GlobalVar->getValueType ();
286
+ } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))
287
+ RootTy = Alloca->getAllocatedType ();
288
+ assert (!isMultiDimensionalArray (RootTy) &&
289
+ " Expected root array type to be flattened" );
290
+
291
+ // If the root type is not an array, we don't need to do any flattening
292
+ if (!isa<ArrayType>(RootTy))
293
+ return false ;
294
+
295
+ Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
282
296
}
283
- }
284
297
285
- bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain (
286
- GetElementPtrInst &GEP) {
287
- GEPData GEPInfo = GEPChainMap.at (&GEP);
288
- return visitGetElementPtrInstInGEPChainBase (GEPInfo, GEP);
289
- }
290
- bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase (
291
- GEPData &GEPInfo, GetElementPtrInst &GEP) {
292
- IRBuilder<> Builder (&GEP);
293
- Value *FlatIndex;
294
- if (GEPInfo.AllIndicesAreConstInt )
295
- FlatIndex = genConstFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
296
- else
297
- FlatIndex =
298
- genInstructionFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
299
-
300
- ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType ;
301
-
302
- // Don't append '.flat' to an empty string. If the SSA name isn't available
303
- // it could conflict with the ParentOperand's name.
304
- std::string FlatName = GEP.hasName () ? GEP.getName ().str () + " .flat" : " " ;
305
-
306
- Value *FlatGEP = Builder.CreateGEP (FlattenedArrayType, GEPInfo.ParentOperand ,
307
- {Builder.getInt32 (0 ), FlatIndex}, FlatName,
308
- GEP.getNoWrapFlags ());
309
-
310
- // Note: Old gep will become an invalid instruction after replaceAllUsesWith.
311
- // Erase the old GEP in the map before to avoid invalid instructions
312
- // and circular references.
313
- GEPChainMap.erase (&GEP);
314
-
315
- GEP.replaceAllUsesWith (FlatGEP);
316
- GEP.eraseFromParent ();
317
- return true ;
318
- }
319
-
320
- bool DXILFlattenArraysVisitor::visitGetElementPtrInst (GetElementPtrInst &GEP) {
321
- auto It = GEPChainMap.find (&GEP);
322
- if (It != GEPChainMap.end ())
323
- return visitGetElementPtrInstInGEPChain (GEP);
324
- if (!isMultiDimensionalArray (GEP.getSourceElementType ()))
325
- return false ;
326
-
327
- ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType ());
328
- IRBuilder<> Builder (&GEP);
329
- auto [TotalElements, BaseType] = getElementCountAndType (ArrType);
330
- ArrayType *FlattenedArrayType = ArrayType::get (BaseType, TotalElements);
331
-
332
- Value *PtrOperand = GEP.getPointerOperand ();
298
+ // GEPs without users or GEPs with non-GEP users should be replaced such that
299
+ // the chain of GEPs they are a part of are collapsed to a single GEP into a
300
+ // flattened array.
301
+ bool ReplaceThisGEP = GEP.users ().empty ();
302
+ for (Value *User : GEP.users ())
303
+ if (!isa<GetElementPtrInst>(User))
304
+ ReplaceThisGEP = true ;
305
+
306
+ if (ReplaceThisGEP) {
307
+ unsigned BytesPerElem =
308
+ DL.getTypeAllocSize (Info.RootFlattenedArrayType ->getArrayElementType ());
309
+ assert (isPowerOf2_32 (BytesPerElem) &&
310
+ " Bytes per element should be a power of 2" );
311
+
312
+ // Compute the 32-bit index for this flattened GEP from the constant and
313
+ // variable byte offsets in the GEPInfo
314
+ IRBuilder<> Builder (&GEP);
315
+ Value *ZeroIndex = Builder.getInt32 (0 );
316
+ uint64_t ConstantOffset =
317
+ Info.ConstantOffset .udiv (BytesPerElem).getZExtValue ();
318
+ assert (ConstantOffset < UINT32_MAX &&
319
+ " Constant byte offset for flat GEP index must fit within 32 bits" );
320
+ Value *FlattenedIndex = Builder.getInt32 (ConstantOffset);
321
+ for (auto [VarIndex, Multiplier] : Info.VariableOffsets ) {
322
+ assert (Multiplier.getActiveBits () <= 32 &&
323
+ " The multiplier for a flat GEP index must fit within 32 bits" );
324
+ assert (VarIndex->getType ()->isIntegerTy (32 ) &&
325
+ " Expected i32-typed GEP indices" );
326
+ Value *VI;
327
+ if (Multiplier.getZExtValue () % BytesPerElem != 0 ) {
328
+ // This can happen, e.g., with i8 GEPs. To handle this we just divide
329
+ // by BytesPerElem using an instruction after multiplying VarIndex by
330
+ // Multiplier.
331
+ VI = Builder.CreateMul (VarIndex,
332
+ Builder.getInt32 (Multiplier.getZExtValue ()));
333
+ VI = Builder.CreateLShr (VI, Builder.getInt32 (Log2_32 (BytesPerElem)));
334
+ } else
335
+ VI = Builder.CreateMul (
336
+ VarIndex,
337
+ Builder.getInt32 (Multiplier.getZExtValue () / BytesPerElem));
338
+ FlattenedIndex = Builder.CreateAdd (FlattenedIndex, VI);
339
+ }
333
340
334
- unsigned GEPChainUseCount = 0 ;
335
- recursivelyCollectGEPs (GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
336
-
337
- // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
338
- // Here recursion is used to get the length of the GEP chain.
339
- // Handle zero uses here because there won't be an update via
340
- // a child in the chain later.
341
- if (GEPChainUseCount == 0 ) {
342
- SmallVector<Value *> Indices;
343
- SmallVector<uint64_t > Dims;
344
- bool AllIndicesAreConstInt = true ;
345
-
346
- // Collect indices and dimensions from the GEP
347
- collectIndicesAndDimsFromGEP (GEP, Indices, Dims, AllIndicesAreConstInt);
348
- GEPData GEPInfo{std::move (FlattenedArrayType), PtrOperand,
349
- std::move (Indices), std::move (Dims), AllIndicesAreConstInt};
350
- return visitGetElementPtrInstInGEPChainBase (GEPInfo, GEP);
341
+ // Construct a new GEP for the flattened array to replace the current GEP
342
+ Value *NewGEP = Builder.CreateGEP (
343
+ Info.RootFlattenedArrayType , Info.RootPointerOperand ,
344
+ {ZeroIndex, FlattenedIndex}, GEP.getName (), GEP.getNoWrapFlags ());
345
+
346
+ // Replace the current GEP with the new GEP. Store GEPInfo into the map
347
+ // for later use in case this GEP was not the end of the chain
348
+ GEPChainInfoMap.insert ({cast<GEPOperator>(NewGEP), std::move (Info)});
349
+ GEP.replaceAllUsesWith (NewGEP);
350
+ GEP.eraseFromParent ();
351
+ return true ;
351
352
}
352
353
354
+ // This GEP is potentially dead at the end of the pass since it may not have
355
+ // any users anymore after GEP chains have been collapsed. We retain store
356
+ // GEPInfo for GEPs down the chain to use to compute their indices.
357
+ GEPChainInfoMap.insert ({cast<GEPOperator>(&GEP), std::move (Info)});
353
358
PotentiallyDeadInstrs.emplace_back (&GEP);
354
359
return false ;
355
360
}
@@ -416,9 +421,8 @@ static Constant *transformInitializer(Constant *Init, Type *OrigType,
416
421
return ConstantArray::get (FlattenedType, FlattenedElements);
417
422
}
418
423
419
- static void
420
- flattenGlobalArrays (Module &M,
421
- DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
424
+ static void flattenGlobalArrays (
425
+ Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
422
426
LLVMContext &Ctx = M.getContext ();
423
427
for (GlobalVariable &G : M.globals ()) {
424
428
Type *OrigType = G.getValueType ();
@@ -456,9 +460,9 @@ flattenGlobalArrays(Module &M,
456
460
457
461
static bool flattenArrays (Module &M) {
458
462
bool MadeChange = false ;
459
- DXILFlattenArraysVisitor Impl;
460
- DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
463
+ SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
461
464
flattenGlobalArrays (M, GlobalMap);
465
+ DXILFlattenArraysVisitor Impl (GlobalMap);
462
466
for (auto &F : make_early_inc_range (M.functions ())) {
463
467
if (F.isDeclaration ())
464
468
continue ;
0 commit comments