@@ -27,6 +27,19 @@ static const int MaxVecSize = 4;
27
27
28
28
using namespace llvm ;
29
29
30
+ // Recursively creates an array-like version of a given vector type.
31
+ static Type *equivalentArrayTypeFromVector (Type *T) {
32
+ if (auto *VecTy = dyn_cast<VectorType>(T))
33
+ return ArrayType::get (VecTy->getElementType (),
34
+ dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
35
+ if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
36
+ Type *NewElementType = equivalentArrayTypeFromVector (ArrayTy->getElementType ());
37
+ return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
38
+ }
39
+ // If it's not a vector or array, return the original type.
40
+ return T;
41
+ }
42
+
30
43
class DXILDataScalarizationLegacy : public ModulePass {
31
44
32
45
public:
@@ -55,7 +68,7 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
55
68
bool visitCastInst (CastInst &CI) { return false ; }
56
69
bool visitBitCastInst (BitCastInst &BCI) { return false ; }
57
70
bool visitInsertElementInst (InsertElementInst &IEI) { return false ; }
58
- bool visitExtractElementInst (ExtractElementInst &EEI) { return false ; }
71
+ bool visitExtractElementInst (ExtractElementInst &EEI);
59
72
bool visitShuffleVectorInst (ShuffleVectorInst &SVI) { return false ; }
60
73
bool visitPHINode (PHINode &PHI) { return false ; }
61
74
bool visitLoadInst (LoadInst &LI);
@@ -90,20 +103,6 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
90
103
return nullptr ; // Not found
91
104
}
92
105
93
- // Recursively creates an array version of the given vector type.
94
- static Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
95
- if (auto *VecTy = dyn_cast<VectorType>(T))
96
- return ArrayType::get (VecTy->getElementType (),
97
- dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
98
- if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
99
- Type *NewElementType =
100
- replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
101
- return ArrayType::get (NewElementType, ArrayTy->getNumElements ());
102
- }
103
- // If it's not a vector or array, return the original type.
104
- return T;
105
- }
106
-
107
106
static bool isArrayOfVectors (Type *T) {
108
107
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
109
108
return isa<VectorType>(ArrType->getElementType ());
@@ -116,8 +115,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
116
115
117
116
ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType ());
118
117
IRBuilder<> Builder (&AI);
119
- LLVMContext &Ctx = AI.getContext ();
120
- Type *NewType = replaceVectorWithArray (ArrType, Ctx);
118
+ Type *NewType = equivalentArrayTypeFromVector (ArrType);
121
119
AllocaInst *ArrAlloca =
122
120
Builder.CreateAlloca (NewType, nullptr , AI.getName () + " .scalarize" );
123
121
ArrAlloca->setAlignment (AI.getAlign ());
@@ -173,6 +171,38 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
173
171
return false ;
174
172
}
175
173
174
+ bool DataScalarizerVisitor::visitExtractElementInst (ExtractElementInst &EEI) {
175
+ // If the index is a constant then we don't need to scalarize it
176
+ Value *Index = EEI.getIndexOperand ();
177
+ Type *IndexTy = Index->getType ();
178
+ if (isa<ConstantInt>(Index))
179
+ return false ;
180
+
181
+ IRBuilder<> Builder (&EEI);
182
+ VectorType *VecTy = EEI.getVectorOperandType ();
183
+ assert (VecTy->getElementCount ().isFixed () &&
184
+ " Vector operand of ExtractElement must have a fixed size" );
185
+
186
+ Type *ArrTy = equivalentArrayTypeFromVector (VecTy);
187
+ Value *ArrAlloca = Builder.CreateAlloca (ArrTy);
188
+
189
+ for (unsigned I = 0 ; I < ArrTy->getArrayNumElements (); ++I) {
190
+ Value *EE = Builder.CreateExtractElement (EEI.getVectorOperand (), I);
191
+ Value *GEP = Builder.CreateInBoundsGEP (
192
+ ArrTy, ArrAlloca,
193
+ {ConstantInt::get (IndexTy, 0 ), ConstantInt::get (IndexTy, I)});
194
+ Builder.CreateStore (EE, GEP);
195
+ }
196
+
197
+ Value *GEP = Builder.CreateInBoundsGEP (ArrTy, ArrAlloca,
198
+ {ConstantInt::get (IndexTy, 0 ), Index});
199
+ Value *Load = Builder.CreateLoad (ArrTy->getArrayElementType (), GEP);
200
+
201
+ EEI.replaceAllUsesWith (Load);
202
+ EEI.eraseFromParent ();
203
+ return true ;
204
+ }
205
+
176
206
bool DataScalarizerVisitor::visitGetElementPtrInst (GetElementPtrInst &GEPI) {
177
207
178
208
unsigned NumOperands = GEPI.getNumOperands ();
@@ -257,7 +287,7 @@ static bool findAndReplaceVectors(Module &M) {
257
287
for (GlobalVariable &G : M.globals ()) {
258
288
Type *OrigType = G.getValueType ();
259
289
260
- Type *NewType = replaceVectorWithArray (OrigType, Ctx );
290
+ Type *NewType = equivalentArrayTypeFromVector (OrigType);
261
291
if (OrigType != NewType) {
262
292
// Create a new global variable with the updated type
263
293
// Note: Initializer is set via transformInitializer
0 commit comments