1010#include " DirectX.h"
1111#include " llvm/ADT/PostOrderIterator.h"
1212#include " llvm/ADT/STLExtras.h"
13+ #include " llvm/IR/DerivedTypes.h"
1314#include " llvm/IR/GlobalVariable.h"
1415#include " llvm/IR/IRBuilder.h"
1516#include " llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
4041class DataScalarizerVisitor : public InstVisitor <DataScalarizerVisitor, bool > {
4142public:
4243 DataScalarizerVisitor () : GlobalMap() {}
43- bool visit (Instruction &I );
44+ bool visit (Function &F );
4445 // InstVisitor methods. They return true if the instruction was scalarized,
4546 // false if nothing changed.
47+ bool visitAllocaInst (AllocaInst &AI);
4648 bool visitInstruction (Instruction &I) { return false ; }
4749 bool visitSelectInst (SelectInst &SI) { return false ; }
4850 bool visitICmpInst (ICmpInst &ICI) { return false ; }
@@ -67,9 +69,14 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
6769 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
6870};
6971
70- bool DataScalarizerVisitor::visit (Instruction &I) {
71- assert (!GlobalMap.empty ());
72- return InstVisitor::visit (I);
72+ bool DataScalarizerVisitor::visit (Function &F) {
73+ bool MadeChange = false ;
74+ ReversePostOrderTraversal<Function *> RPOT (&F);
75+ for (BasicBlock *BB : make_early_inc_range (RPOT)) {
76+ for (Instruction &I : make_early_inc_range (*BB))
77+ MadeChange |= InstVisitor::visit (I);
78+ }
79+ return MadeChange;
7380}
7481
7582GlobalVariable *
@@ -83,6 +90,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
8390 return nullptr ; // Not found
8491}
8592
93+ // Recursively creates an array version of the given vector type.
94+ static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
95+ if (auto *VecTy = dyn_cast<VectorType>(T))
96+ return ArrayType::get (VecTy->getElementType (),
97+ dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
98+ if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
99+ Type *NewElementType =
100+ replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
101+ return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
102+ }
103+ // If it's not a vector or array, return the original type.
104+ return T;
105+ }
106+
107+ static bool isArrayOfVectors (Type *T) {
108+ if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
109+ return isa<VectorType>(ArrType->getElementType ());
110+ return false ;
111+ }
112+
113+ bool DataScalarizerVisitor::visitAllocaInst (AllocaInst &AI) {
114+ if (!isArrayOfVectors (AI.getAllocatedType ()))
115+ return false ;
116+
117+ ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType ());
118+ IRBuilder<> Builder (&AI);
119+ LLVMContext &Ctx = AI.getContext ();
120+ Type *NewType = replaceVectorWithArray (ArrType, Ctx);
121+ AllocaInst *ArrAlloca =
122+ Builder.CreateAlloca (NewType, nullptr , AI.getName () + " .scalarize" );
123+ ArrAlloca->setAlignment (AI.getAlign ());
124+ AI.replaceAllUsesWith (ArrAlloca);
125+ AI.eraseFromParent ();
126+ return true ;
127+ }
128+
86129bool DataScalarizerVisitor::visitLoadInst (LoadInst &LI) {
87130 unsigned NumOperands = LI.getNumOperands ();
88131 for (unsigned I = 0 ; I < NumOperands; ++I) {
@@ -154,20 +197,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
154197 return true ;
155198}
156199
157- // Recursively Creates and Array like version of the given vector like type.
158- static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
159- if (auto *VecTy = dyn_cast<VectorType>(T))
160- return ArrayType::get (VecTy->getElementType (),
161- dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
162- if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
163- Type *NewElementType =
164- replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
165- return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
166- }
167- // If it's not a vector or array, return the original type.
168- return T;
169- }
170-
171200Constant *transformInitializer (Constant *Init, Type *OrigType, Type *NewType,
172201 LLVMContext &Ctx) {
173202 // Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +282,15 @@ static bool findAndReplaceVectors(Module &M) {
253282 // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
254283 // type equality. Instead we will use the visitor pattern.
255284 Impl.GlobalMap [&G] = NewGlobal;
256- for (User *U : make_early_inc_range (G.users ())) {
257- if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
258- ConstantExpr *CE = cast<ConstantExpr>(U);
259- for (User *UCE : make_early_inc_range (CE->users ())) {
260- if (Instruction *Inst = dyn_cast<Instruction>(UCE))
261- Impl.visit (*Inst);
262- }
263- }
264- if (Instruction *Inst = dyn_cast<Instruction>(U))
265- Impl.visit (*Inst);
266- }
267285 }
268286 }
269287
288+ for (auto &F : make_early_inc_range (M.functions ())) {
289+ if (F.isDeclaration ())
290+ continue ;
291+ MadeChange |= Impl.visit (F);
292+ }
293+
270294 // Remove the old globals after the iteration
271295 for (auto &[Old, New] : Impl.GlobalMap ) {
272296 Old->eraseFromParent ();
0 commit comments