Skip to content

Commit 352215c

Browse files
authored
[DirectX] Simplify and correct the flattening of GEPs in DXILFlattenArrays (#146173)
In tandem with #146800, this PR fixes #145370 This PR simplifies the logic for collapsing GEP chains and replacing GEPs to multidimensional arrays with GEPs to flattened arrays. This implementation avoids unnecessary recursion and more robustly computes the index to the flattened array by using the GEPOperator's collectOffset function, which has the side effect of allowing "i8 GEPs" and other types of GEPs to be handled naturally in the flattening / collapsing of GEP chains. Furthermore, a handful of LLVM DirectX CodeGen tests have been edited to fix incorrect GEP offsets, mismatched types (e.g., loading i32s from a an array of floats), and typos.
1 parent ec90786 commit 352215c

File tree

6 files changed

+247
-207
lines changed

6 files changed

+247
-207
lines changed

llvm/lib/Target/DirectX/DXILFlattenArrays.cpp

Lines changed: 151 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/IR/InstVisitor.h"
2121
#include "llvm/IR/ReplaceConstant.h"
2222
#include "llvm/Support/Casting.h"
23+
#include "llvm/Support/MathExtras.h"
2324
#include "llvm/Transforms/Utils/Local.h"
2425
#include <cassert>
2526
#include <cstddef>
@@ -40,18 +41,19 @@ class DXILFlattenArraysLegacy : public ModulePass {
4041
static char ID; // Pass identification.
4142
};
4243

43-
struct GEPData {
44-
ArrayType *ParentArrayType;
45-
Value *ParentOperand;
46-
SmallVector<Value *> Indices;
47-
SmallVector<uint64_t> Dims;
48-
bool AllIndicesAreConstInt;
44+
struct GEPInfo {
45+
ArrayType *RootFlattenedArrayType;
46+
Value *RootPointerOperand;
47+
SmallMapVector<Value *, APInt, 4> VariableOffsets;
48+
APInt ConstantOffset;
4949
};
5050

5151
class DXILFlattenArraysVisitor
5252
: public InstVisitor<DXILFlattenArraysVisitor, bool> {
5353
public:
54-
DXILFlattenArraysVisitor() {}
54+
DXILFlattenArraysVisitor(
55+
SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
56+
: GlobalMap(GlobalMap) {}
5557
bool visit(Function &F);
5658
// InstVisitor methods. They return true if the instruction was scalarized,
5759
// false if nothing changed.
@@ -78,35 +80,20 @@ class DXILFlattenArraysVisitor
7880

7981
private:
8082
SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
81-
DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
83+
SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
84+
SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
8285
bool finish();
8386
ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
8487
ArrayRef<uint64_t> Dims,
8588
IRBuilder<> &Builder);
8689
Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
8790
ArrayRef<uint64_t> Dims,
8891
IRBuilder<> &Builder);
89-
90-
// Helper function to collect indices and dimensions from a GEP instruction
91-
void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP,
92-
SmallVectorImpl<Value *> &Indices,
93-
SmallVectorImpl<uint64_t> &Dims,
94-
bool &AllIndicesAreConstInt);
95-
96-
void
97-
recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
98-
ArrayType *FlattenedArrayType, Value *PtrOperand,
99-
unsigned &GEPChainUseCount,
100-
SmallVector<Value *> Indices = SmallVector<Value *>(),
101-
SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
102-
bool AllIndicesAreConstInt = true);
103-
bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
104-
bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
105-
GetElementPtrInst &GEP);
10692
};
10793
} // namespace
10894

10995
bool DXILFlattenArraysVisitor::finish() {
96+
GEPChainInfoMap.clear();
11097
RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
11198
return true;
11299
}
@@ -225,131 +212,149 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
225212
return true;
226213
}
227214

228-
void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP(
229-
GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
230-
SmallVectorImpl<uint64_t> &Dims, bool &AllIndicesAreConstInt) {
231-
232-
Type *CurrentType = GEP.getSourceElementType();
233-
234-
// Note index 0 is the ptr index.
235-
for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) {
236-
Indices.push_back(Index);
237-
AllIndicesAreConstInt &= isa<ConstantInt>(Index);
215+
bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
216+
// Do not visit GEPs more than once
217+
if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
218+
return false;
238219

239-
if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
240-
Dims.push_back(ArrayTy->getNumElements());
241-
CurrentType = ArrayTy->getElementType();
242-
} else {
243-
assert(false && "Expected array type in GEP chain");
244-
}
220+
Value *PtrOperand = GEP.getPointerOperand();
221+
// It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI
222+
// node unless HLSL has pointers. If this assumption is incorrect or HLSL gets
223+
// pointer types, then the handling of this case can be implemented later.
224+
assert(!isa<PHINode>(PtrOperand) &&
225+
"Pointer operand of GEP should not be a PHI Node");
226+
227+
// Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
228+
// it can be visited
229+
if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
230+
PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
231+
GetElementPtrInst *OldGEPI =
232+
cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
233+
OldGEPI->insertBefore(GEP.getIterator());
234+
235+
IRBuilder<> Builder(&GEP);
236+
SmallVector<Value *> Indices(GEP.indices());
237+
Value *NewGEP =
238+
Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
239+
GEP.getName(), GEP.getNoWrapFlags());
240+
assert(isa<GetElementPtrInst>(NewGEP) &&
241+
"Expected newly-created GEP to be an instruction");
242+
GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
243+
244+
GEP.replaceAllUsesWith(NewGEPI);
245+
GEP.eraseFromParent();
246+
visitGetElementPtrInst(*OldGEPI);
247+
visitGetElementPtrInst(*NewGEPI);
248+
return true;
245249
}
246-
}
247-
248-
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
249-
GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
250-
Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
251-
SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
252-
// Check if this GEP is already in the map to avoid circular references
253-
if (GEPChainMap.count(&CurrGEP) > 0)
254-
return;
255250

256-
// Collect indices and dimensions from the current GEP
257-
collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt);
258-
bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
259-
if (!IsMultiDimArr) {
260-
assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
261-
GEPChainMap.insert(
262-
{&CurrGEP,
263-
{std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
264-
std::move(Dims), AllIndicesAreConstInt}});
265-
return;
266-
}
267-
bool GepUses = false;
268-
for (auto *User : CurrGEP.users()) {
269-
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
270-
recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
271-
++GEPChainUseCount, Indices, Dims,
272-
AllIndicesAreConstInt);
273-
GepUses = true;
274-
}
275-
}
276-
// This case is just incase the gep chain doesn't end with a 1d array.
277-
if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
278-
GEPChainMap.insert(
279-
{&CurrGEP,
280-
{std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
281-
std::move(Dims), AllIndicesAreConstInt}});
251+
// Construct GEPInfo for this GEP
252+
GEPInfo Info;
253+
254+
// Obtain the variable and constant byte offsets computed by this GEP
255+
const DataLayout &DL = GEP.getDataLayout();
256+
unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
257+
Info.ConstantOffset = {BitWidth, 0};
258+
[[maybe_unused]] bool Success = GEP.collectOffset(
259+
DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset);
260+
assert(Success && "Failed to collect offsets for GEP");
261+
262+
// If there is a parent GEP, inherit the root array type and pointer, and
263+
// merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
264+
// chain and we need to deterine the root array type
265+
if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
266+
assert(GEPChainInfoMap.contains(PtrOpGEP) &&
267+
"Expected parent GEP to be visited before this GEP");
268+
GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
269+
Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
270+
Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
271+
for (auto &VariableOffset : PGEPInfo.VariableOffsets)
272+
Info.VariableOffsets.insert(VariableOffset);
273+
Info.ConstantOffset += PGEPInfo.ConstantOffset;
274+
} else {
275+
Info.RootPointerOperand = PtrOperand;
276+
277+
// We should try to determine the type of the root from the pointer rather
278+
// than the GEP's source element type because this could be a scalar GEP
279+
// into an array-typed pointer from an Alloca or Global Variable.
280+
Type *RootTy = GEP.getSourceElementType();
281+
if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
282+
if (GlobalMap.contains(GlobalVar))
283+
GlobalVar = GlobalMap[GlobalVar];
284+
Info.RootPointerOperand = GlobalVar;
285+
RootTy = GlobalVar->getValueType();
286+
} else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))
287+
RootTy = Alloca->getAllocatedType();
288+
assert(!isMultiDimensionalArray(RootTy) &&
289+
"Expected root array type to be flattened");
290+
291+
// If the root type is not an array, we don't need to do any flattening
292+
if (!isa<ArrayType>(RootTy))
293+
return false;
294+
295+
Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
282296
}
283-
}
284297

285-
bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
286-
GetElementPtrInst &GEP) {
287-
GEPData GEPInfo = GEPChainMap.at(&GEP);
288-
return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
289-
}
290-
bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
291-
GEPData &GEPInfo, GetElementPtrInst &GEP) {
292-
IRBuilder<> Builder(&GEP);
293-
Value *FlatIndex;
294-
if (GEPInfo.AllIndicesAreConstInt)
295-
FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
296-
else
297-
FlatIndex =
298-
genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
299-
300-
ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
301-
302-
// Don't append '.flat' to an empty string. If the SSA name isn't available
303-
// it could conflict with the ParentOperand's name.
304-
std::string FlatName = GEP.hasName() ? GEP.getName().str() + ".flat" : "";
305-
306-
Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParentOperand,
307-
{Builder.getInt32(0), FlatIndex}, FlatName,
308-
GEP.getNoWrapFlags());
309-
310-
// Note: Old gep will become an invalid instruction after replaceAllUsesWith.
311-
// Erase the old GEP in the map before to avoid invalid instructions
312-
// and circular references.
313-
GEPChainMap.erase(&GEP);
314-
315-
GEP.replaceAllUsesWith(FlatGEP);
316-
GEP.eraseFromParent();
317-
return true;
318-
}
319-
320-
bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
321-
auto It = GEPChainMap.find(&GEP);
322-
if (It != GEPChainMap.end())
323-
return visitGetElementPtrInstInGEPChain(GEP);
324-
if (!isMultiDimensionalArray(GEP.getSourceElementType()))
325-
return false;
326-
327-
ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
328-
IRBuilder<> Builder(&GEP);
329-
auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
330-
ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
331-
332-
Value *PtrOperand = GEP.getPointerOperand();
298+
// GEPs without users or GEPs with non-GEP users should be replaced such that
299+
// the chain of GEPs they are a part of are collapsed to a single GEP into a
300+
// flattened array.
301+
bool ReplaceThisGEP = GEP.users().empty();
302+
for (Value *User : GEP.users())
303+
if (!isa<GetElementPtrInst>(User))
304+
ReplaceThisGEP = true;
305+
306+
if (ReplaceThisGEP) {
307+
unsigned BytesPerElem =
308+
DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType());
309+
assert(isPowerOf2_32(BytesPerElem) &&
310+
"Bytes per element should be a power of 2");
311+
312+
// Compute the 32-bit index for this flattened GEP from the constant and
313+
// variable byte offsets in the GEPInfo
314+
IRBuilder<> Builder(&GEP);
315+
Value *ZeroIndex = Builder.getInt32(0);
316+
uint64_t ConstantOffset =
317+
Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
318+
assert(ConstantOffset < UINT32_MAX &&
319+
"Constant byte offset for flat GEP index must fit within 32 bits");
320+
Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
321+
for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
322+
assert(Multiplier.getActiveBits() <= 32 &&
323+
"The multiplier for a flat GEP index must fit within 32 bits");
324+
assert(VarIndex->getType()->isIntegerTy(32) &&
325+
"Expected i32-typed GEP indices");
326+
Value *VI;
327+
if (Multiplier.getZExtValue() % BytesPerElem != 0) {
328+
// This can happen, e.g., with i8 GEPs. To handle this we just divide
329+
// by BytesPerElem using an instruction after multiplying VarIndex by
330+
// Multiplier.
331+
VI = Builder.CreateMul(VarIndex,
332+
Builder.getInt32(Multiplier.getZExtValue()));
333+
VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem)));
334+
} else
335+
VI = Builder.CreateMul(
336+
VarIndex,
337+
Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));
338+
FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);
339+
}
333340

334-
unsigned GEPChainUseCount = 0;
335-
recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
336-
337-
// NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
338-
// Here recursion is used to get the length of the GEP chain.
339-
// Handle zero uses here because there won't be an update via
340-
// a child in the chain later.
341-
if (GEPChainUseCount == 0) {
342-
SmallVector<Value *> Indices;
343-
SmallVector<uint64_t> Dims;
344-
bool AllIndicesAreConstInt = true;
345-
346-
// Collect indices and dimensions from the GEP
347-
collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt);
348-
GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
349-
std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
350-
return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
341+
// Construct a new GEP for the flattened array to replace the current GEP
342+
Value *NewGEP = Builder.CreateGEP(
343+
Info.RootFlattenedArrayType, Info.RootPointerOperand,
344+
{ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
345+
346+
// Replace the current GEP with the new GEP. Store GEPInfo into the map
347+
// for later use in case this GEP was not the end of the chain
348+
GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
349+
GEP.replaceAllUsesWith(NewGEP);
350+
GEP.eraseFromParent();
351+
return true;
351352
}
352353

354+
// This GEP is potentially dead at the end of the pass since it may not have
355+
// any users anymore after GEP chains have been collapsed. We retain store
356+
// GEPInfo for GEPs down the chain to use to compute their indices.
357+
GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
353358
PotentiallyDeadInstrs.emplace_back(&GEP);
354359
return false;
355360
}
@@ -416,9 +421,8 @@ static Constant *transformInitializer(Constant *Init, Type *OrigType,
416421
return ConstantArray::get(FlattenedType, FlattenedElements);
417422
}
418423

419-
static void
420-
flattenGlobalArrays(Module &M,
421-
DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
424+
static void flattenGlobalArrays(
425+
Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
422426
LLVMContext &Ctx = M.getContext();
423427
for (GlobalVariable &G : M.globals()) {
424428
Type *OrigType = G.getValueType();
@@ -456,9 +460,9 @@ flattenGlobalArrays(Module &M,
456460

457461
static bool flattenArrays(Module &M) {
458462
bool MadeChange = false;
459-
DXILFlattenArraysVisitor Impl;
460-
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
463+
SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
461464
flattenGlobalArrays(M, GlobalMap);
465+
DXILFlattenArraysVisitor Impl(GlobalMap);
462466
for (auto &F : make_early_inc_range(M.functions())) {
463467
if (F.isDeclaration())
464468
continue;

0 commit comments

Comments
 (0)