10
10
#include " DirectX.h"
11
11
#include " llvm/ADT/PostOrderIterator.h"
12
12
#include " llvm/ADT/STLExtras.h"
13
+ #include " llvm/IR/DerivedTypes.h"
13
14
#include " llvm/IR/GlobalVariable.h"
14
15
#include " llvm/IR/IRBuilder.h"
15
16
#include " llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
40
41
class DataScalarizerVisitor : public InstVisitor <DataScalarizerVisitor, bool > {
41
42
public:
42
43
DataScalarizerVisitor () : GlobalMap() {}
43
- bool visit (Instruction &I );
44
+ bool visit (Function &F );
44
45
// InstVisitor methods. They return true if the instruction was scalarized,
45
46
// false if nothing changed.
47
+ bool visitAllocaInst (AllocaInst &AI);
46
48
bool visitInstruction (Instruction &I) { return false ; }
47
49
bool visitSelectInst (SelectInst &SI) { return false ; }
48
50
bool visitICmpInst (ICmpInst &ICI) { return false ; }
@@ -65,11 +67,17 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
65
67
private:
66
68
GlobalVariable *lookupReplacementGlobal (Value *CurrOperand);
67
69
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
70
+ static bool isArrayOfVectors (Type *T);
68
71
};
69
72
70
- bool DataScalarizerVisitor::visit (Instruction &I) {
71
- assert (!GlobalMap.empty ());
72
- return InstVisitor::visit (I);
73
+ bool DataScalarizerVisitor::visit (Function &F) {
74
+ bool MadeChange = false ;
75
+ ReversePostOrderTraversal<Function *> RPOT (&F);
76
+ for (BasicBlock *BB : make_early_inc_range (RPOT)) {
77
+ for (Instruction &I : make_early_inc_range (*BB))
78
+ MadeChange |= InstVisitor::visit (I);
79
+ }
80
+ return MadeChange;
73
81
}
74
82
75
83
GlobalVariable *
@@ -83,6 +91,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
83
91
return nullptr ; // Not found
84
92
}
85
93
94
+ // Recursively Creates and Array like version of the given vector like type.
95
+ static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
96
+ if (auto *VecTy = dyn_cast<VectorType>(T))
97
+ return ArrayType::get (VecTy->getElementType (),
98
+ dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
99
+ if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
100
+ Type *NewElementType =
101
+ replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
102
+ return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
103
+ }
104
+ // If it's not a vector or array, return the original type.
105
+ return T;
106
+ }
107
+
108
+ bool DataScalarizerVisitor::isArrayOfVectors (Type *T) {
109
+ if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
110
+ return isa<VectorType>(ArrType->getElementType ());
111
+ return false ;
112
+ }
113
+
114
+ bool DataScalarizerVisitor::visitAllocaInst (AllocaInst &AI) {
115
+ if (!isArrayOfVectors (AI.getAllocatedType ()))
116
+ return false ;
117
+
118
+ ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType ());
119
+ IRBuilder<> Builder (&AI);
120
+ LLVMContext &Ctx = AI.getContext ();
121
+ Type *NewType = replaceVectorWithArray (ArrType, Ctx);
122
+ AllocaInst *ArrAlloca =
123
+ Builder.CreateAlloca (NewType, nullptr , AI.getName () + " .scalarize" );
124
+ ArrAlloca->setAlignment (AI.getAlign ());
125
+ AI.replaceAllUsesWith (ArrAlloca);
126
+ AI.eraseFromParent ();
127
+ return true ;
128
+ }
129
+
86
130
bool DataScalarizerVisitor::visitLoadInst (LoadInst &LI) {
87
131
unsigned NumOperands = LI.getNumOperands ();
88
132
for (unsigned I = 0 ; I < NumOperands; ++I) {
@@ -154,20 +198,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
154
198
return true ;
155
199
}
156
200
157
- // Recursively Creates and Array like version of the given vector like type.
158
- static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
159
- if (auto *VecTy = dyn_cast<VectorType>(T))
160
- return ArrayType::get (VecTy->getElementType (),
161
- dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
162
- if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
163
- Type *NewElementType =
164
- replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
165
- return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
166
- }
167
- // If it's not a vector or array, return the original type.
168
- return T;
169
- }
170
-
171
201
Constant *transformInitializer (Constant *Init, Type *OrigType, Type *NewType,
172
202
LLVMContext &Ctx) {
173
203
// Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +283,15 @@ static bool findAndReplaceVectors(Module &M) {
253
283
// Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
254
284
// type equality. Instead we will use the visitor pattern.
255
285
Impl.GlobalMap [&G] = NewGlobal;
256
- for (User *U : make_early_inc_range (G.users ())) {
257
- if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
258
- ConstantExpr *CE = cast<ConstantExpr>(U);
259
- for (User *UCE : make_early_inc_range (CE->users ())) {
260
- if (Instruction *Inst = dyn_cast<Instruction>(UCE))
261
- Impl.visit (*Inst);
262
- }
263
- }
264
- if (Instruction *Inst = dyn_cast<Instruction>(U))
265
- Impl.visit (*Inst);
266
- }
267
286
}
268
287
}
269
288
289
+ for (auto &F : make_early_inc_range (M.functions ())) {
290
+ if (F.isDeclaration ())
291
+ continue ;
292
+ MadeChange |= Impl.visit (F);
293
+ }
294
+
270
295
// Remove the old globals after the iteration
271
296
for (auto &[Old, New] : Impl.GlobalMap ) {
272
297
Old->eraseFromParent ();
0 commit comments