55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===---------------------------------------------------------------------===//
8-
98// /
109// / \file This file contains a pass to flatten arrays for the DirectX Backend.
11- //
10+ // /
1211// ===----------------------------------------------------------------------===//
1312
1413#include " DXILFlattenArrays.h"
2625#include < cassert>
2726#include < cstddef>
2827#include < cstdint>
28+ #include < utility>
2929
3030#define DEBUG_TYPE " dxil-flatten-arrays"
3131
3232using namespace llvm ;
33+ namespace {
3334
3435class DXILFlattenArraysLegacy : public ModulePass {
3536
@@ -75,19 +76,18 @@ class DXILFlattenArraysVisitor
7576 bool visitCallInst (CallInst &ICI) { return false ; }
7677 bool visitFreezeInst (FreezeInst &FI) { return false ; }
7778 static bool isMultiDimensionalArray (Type *T);
78- static unsigned getTotalElements (Type *ArrayTy);
79- static Type *getBaseElementType (Type *ArrayTy);
79+ static std::pair<unsigned , Type *> getElementCountAndType (Type *ArrayTy);
8080
8181private:
82- SmallVector<WeakTrackingVH, 32 > PotentiallyDeadInstrs;
82+ SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
8383 DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
8484 bool finish ();
85- ConstantInt *constFlattenIndices (ArrayRef<Value *> Indices,
86- ArrayRef<uint64_t > Dims,
87- IRBuilder<> &Builder);
88- Value *instructionFlattenIndices (ArrayRef<Value *> Indices,
89- ArrayRef<uint64_t > Dims,
90- IRBuilder<> &Builder);
85+ ConstantInt *genConstFlattenIndices (ArrayRef<Value *> Indices,
86+ ArrayRef<uint64_t > Dims,
87+ IRBuilder<> &Builder);
88+ Value *genInstructionFlattenIndices (ArrayRef<Value *> Indices,
89+ ArrayRef<uint64_t > Dims,
90+ IRBuilder<> &Builder);
9191 void
9292 recursivelyCollectGEPs (GetElementPtrInst &CurrGEP,
9393 ArrayType *FlattenedArrayType, Value *PtrOperand,
@@ -99,6 +99,7 @@ class DXILFlattenArraysVisitor
9999 bool visitGetElementPtrInstInGEPChainBase (GEPData &GEPInfo,
100100 GetElementPtrInst &GEP);
101101};
102+ } // namespace
102103
103104bool DXILFlattenArraysVisitor::finish () {
104105 RecursivelyDeleteTriviallyDeadInstructionsPermissive (PotentiallyDeadInstrs);
@@ -111,25 +112,18 @@ bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
111112 return false ;
112113}
113114
114- unsigned DXILFlattenArraysVisitor::getTotalElements (Type *ArrayTy) {
115+ std::pair<unsigned , Type *>
116+ DXILFlattenArraysVisitor::getElementCountAndType (Type *ArrayTy) {
115117 unsigned TotalElements = 1 ;
116118 Type *CurrArrayTy = ArrayTy;
117119 while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
118120 TotalElements *= InnerArrayTy->getNumElements ();
119121 CurrArrayTy = InnerArrayTy->getElementType ();
120122 }
121- return TotalElements;
122- }
123-
124- Type *DXILFlattenArraysVisitor::getBaseElementType (Type *ArrayTy) {
125- Type *CurrArrayTy = ArrayTy;
126- while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
127- CurrArrayTy = InnerArrayTy->getElementType ();
128- }
129- return CurrArrayTy;
123+ return std::make_pair (TotalElements, CurrArrayTy);
130124}
131125
132- ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices (
126+ ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices (
133127 ArrayRef<Value *> Indices, ArrayRef<uint64_t > Dims, IRBuilder<> &Builder) {
134128 assert (Indices.size () == Dims.size () &&
135129 " Indicies and dimmensions should be the same" );
@@ -146,7 +140,7 @@ ConstantInt *DXILFlattenArraysVisitor::constFlattenIndices(
146140 return Builder.getInt32 (FlatIndex);
147141}
148142
149- Value *DXILFlattenArraysVisitor::instructionFlattenIndices (
143+ Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices (
150144 ArrayRef<Value *> Indices, ArrayRef<uint64_t > Dims, IRBuilder<> &Builder) {
151145 if (Indices.size () == 1 )
152146 return Indices[0 ];
@@ -202,10 +196,9 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
202196
203197 ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType ());
204198 IRBuilder<> Builder (&AI);
205- unsigned TotalElements = getTotalElements (ArrType);
199+ auto [ TotalElements, BaseType] = getElementCountAndType (ArrType);
206200
207- ArrayType *FattenedArrayType =
208- ArrayType::get (getBaseElementType (ArrType), TotalElements);
201+ ArrayType *FattenedArrayType = ArrayType::get (BaseType, TotalElements);
209202 AllocaInst *FlatAlloca =
210203 Builder.CreateAlloca (FattenedArrayType, nullptr , AI.getName () + " .flat" );
211204 FlatAlloca->setAlignment (AI.getAlign ());
@@ -261,10 +254,10 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
261254 IRBuilder<> Builder (&GEP);
262255 Value *FlatIndex;
263256 if (GEPInfo.AllIndicesAreConstInt )
264- FlatIndex = constFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
257+ FlatIndex = genConstFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
265258 else
266259 FlatIndex =
267- instructionFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
260+ genInstructionFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
268261
269262 ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType ;
270263 Value *FlatGEP =
@@ -285,9 +278,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
285278
286279 ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType ());
287280 IRBuilder<> Builder (&GEP);
288- unsigned TotalElements = getTotalElements (ArrType);
289- ArrayType *FlattenedArrayType =
290- ArrayType::get (getBaseElementType (ArrType), TotalElements);
281+ auto [TotalElements, BaseType] = getElementCountAndType (ArrType);
282+ ArrayType *FlattenedArrayType = ArrayType::get (BaseType, TotalElements);
291283
292284 Value *PtrOperand = GEP.getPointerOperand ();
293285
@@ -313,7 +305,6 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
313305
314306bool DXILFlattenArraysVisitor::visit (Function &F) {
315307 bool MadeChange = false ;
316- // //for (BasicBlock &BB : make_early_inc_range(F)) {
317308 ReversePostOrderTraversal<Function *> RPOT (&F);
318309 for (BasicBlock *BB : make_early_inc_range (RPOT)) {
319310 for (Instruction &I : make_early_inc_range (*BB)) {
@@ -345,8 +336,7 @@ static void collectElements(Constant *Init,
345336 collectElements (DataArrayConstant->getElementAsConstant (I), Elements);
346337 }
347338 } else {
348- assert (
349- false &&
339+ llvm_unreachable (
350340 " Expected a ConstantArray or ConstantDataArray for array initializer!" );
351341 }
352342}
@@ -382,10 +372,9 @@ flattenGlobalArrays(Module &M,
382372 continue ;
383373
384374 ArrayType *ArrType = cast<ArrayType>(OrigType);
385- unsigned TotalElements =
386- DXILFlattenArraysVisitor::getTotalElements (ArrType);
387- ArrayType *FattenedArrayType = ArrayType::get (
388- DXILFlattenArraysVisitor::getBaseElementType (ArrType), TotalElements);
375+ auto [TotalElements, BaseType] =
376+ DXILFlattenArraysVisitor::getElementCountAndType (ArrType);
377+ ArrayType *FattenedArrayType = ArrayType::get (BaseType, TotalElements);
389378
390379 // Create a new global variable with the updated type
391380 // Note: Initializer is set via transformInitializer
0 commit comments