1010#include " DirectX.h"
1111#include " llvm/ADT/PostOrderIterator.h"
1212#include " llvm/ADT/STLExtras.h"
13+ #include " llvm/IR/DerivedTypes.h"
1314#include " llvm/IR/GlobalVariable.h"
1415#include " llvm/IR/IRBuilder.h"
1516#include " llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
4041class DataScalarizerVisitor : public InstVisitor <DataScalarizerVisitor, bool > {
4142public:
4243 DataScalarizerVisitor () : GlobalMap() {}
43- bool visit (Instruction &I );
44+ bool visit (Function &F );
4445 // InstVisitor methods. They return true if the instruction was scalarized,
4546 // false if nothing changed.
47+ bool visitAllocaInst (AllocaInst &AI);
4648 bool visitInstruction (Instruction &I) { return false ; }
4749 bool visitSelectInst (SelectInst &SI) { return false ; }
4850 bool visitICmpInst (ICmpInst &ICI) { return false ; }
@@ -65,11 +67,17 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
6567private:
6668 GlobalVariable *lookupReplacementGlobal (Value *CurrOperand);
6769 DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
70+ static bool isArrayOfVectors (Type *T);
6871};
6972
70- bool DataScalarizerVisitor::visit (Instruction &I) {
71- assert (!GlobalMap.empty ());
72- return InstVisitor::visit (I);
73+ bool DataScalarizerVisitor::visit (Function &F) {
74+ bool MadeChange = false ;
75+ ReversePostOrderTraversal<Function *> RPOT (&F);
76+ for (BasicBlock *BB : make_early_inc_range (RPOT)) {
77+ for (Instruction &I : make_early_inc_range (*BB))
78+ MadeChange |= InstVisitor::visit (I);
79+ }
80+ return MadeChange;
7381}
7482
7583GlobalVariable *
@@ -83,6 +91,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
8391 return nullptr ; // Not found
8492}
8593
94+ // Recursively Creates and Array like version of the given vector like type.
95+ static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
96+ if (auto *VecTy = dyn_cast<VectorType>(T))
97+ return ArrayType::get (VecTy->getElementType (),
98+ dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
99+ if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
100+ Type *NewElementType =
101+ replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
102+ return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
103+ }
104+ // If it's not a vector or array, return the original type.
105+ return T;
106+ }
107+
108+ bool DataScalarizerVisitor::isArrayOfVectors (Type *T) {
109+ if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
110+ return isa<VectorType>(ArrType->getElementType ());
111+ return false ;
112+ }
113+
114+ bool DataScalarizerVisitor::visitAllocaInst (AllocaInst &AI) {
115+ if (!isArrayOfVectors (AI.getAllocatedType ()))
116+ return false ;
117+
118+ ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType ());
119+ IRBuilder<> Builder (&AI);
120+ LLVMContext &Ctx = AI.getContext ();
121+ Type *NewType = replaceVectorWithArray (ArrType, Ctx);
122+ AllocaInst *ArrAlloca =
123+ Builder.CreateAlloca (NewType, nullptr , AI.getName () + " .scalarize" );
124+ ArrAlloca->setAlignment (AI.getAlign ());
125+ AI.replaceAllUsesWith (ArrAlloca);
126+ AI.eraseFromParent ();
127+ return true ;
128+ }
129+
86130bool DataScalarizerVisitor::visitLoadInst (LoadInst &LI) {
87131 unsigned NumOperands = LI.getNumOperands ();
88132 for (unsigned I = 0 ; I < NumOperands; ++I) {
@@ -154,20 +198,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
154198 return true ;
155199}
156200
157- // Recursively Creates and Array like version of the given vector like type.
158- static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
159- if (auto *VecTy = dyn_cast<VectorType>(T))
160- return ArrayType::get (VecTy->getElementType (),
161- dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
162- if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
163- Type *NewElementType =
164- replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
165- return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
166- }
167- // If it's not a vector or array, return the original type.
168- return T;
169- }
170-
171201Constant *transformInitializer (Constant *Init, Type *OrigType, Type *NewType,
172202 LLVMContext &Ctx) {
173203 // Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +283,15 @@ static bool findAndReplaceVectors(Module &M) {
253283 // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
254284 // type equality. Instead we will use the visitor pattern.
255285 Impl.GlobalMap [&G] = NewGlobal;
256- for (User *U : make_early_inc_range (G.users ())) {
257- if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
258- ConstantExpr *CE = cast<ConstantExpr>(U);
259- for (User *UCE : make_early_inc_range (CE->users ())) {
260- if (Instruction *Inst = dyn_cast<Instruction>(UCE))
261- Impl.visit (*Inst);
262- }
263- }
264- if (Instruction *Inst = dyn_cast<Instruction>(U))
265- Impl.visit (*Inst);
266- }
267286 }
268287 }
269288
289+ for (auto &F : make_early_inc_range (M.functions ())) {
290+ if (F.isDeclaration ())
291+ continue ;
292+ MadeChange |= Impl.visit (F);
293+ }
294+
270295 // Remove the old globals after the iteration
271296 for (auto &[Old, New] : Impl.GlobalMap ) {
272297 Old->eraseFromParent ();
0 commit comments