@@ -27,6 +27,19 @@ static const int MaxVecSize = 4;
2727
2828using  namespace  llvm ; 
2929
30+ //  Recursively creates an array-like version of a given vector type.
31+ static  Type *equivalentArrayTypeFromVector (Type *T) {
32+   if  (auto  *VecTy = dyn_cast<VectorType>(T))
33+     return  ArrayType::get (VecTy->getElementType (),
34+                           dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
35+   if  (auto  *ArrayTy = dyn_cast<ArrayType>(T)) {
36+     Type *NewElementType = equivalentArrayTypeFromVector (ArrayTy->getElementType ());
37+     return  ArrayType::get (NewElementType, ArrayTy->getNumElements ());
38+   }
39+   //  If it's not a vector or array, return the original type.
40+   return  T;
41+ }
42+ 
3043class  DXILDataScalarizationLegacy  : public  ModulePass  {
3144
3245public: 
@@ -55,7 +68,7 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
5568  bool  visitCastInst (CastInst &CI) { return  false ; }
5669  bool  visitBitCastInst (BitCastInst &BCI) { return  false ; }
5770  bool  visitInsertElementInst (InsertElementInst &IEI) { return  false ; }
58-   bool  visitExtractElementInst (ExtractElementInst &EEI) {  return   false ; } 
71+   bool  visitExtractElementInst (ExtractElementInst &EEI); 
5972  bool  visitShuffleVectorInst (ShuffleVectorInst &SVI) { return  false ; }
6073  bool  visitPHINode (PHINode &PHI) { return  false ; }
6174  bool  visitLoadInst (LoadInst &LI);
@@ -90,20 +103,6 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
90103  return  nullptr ; //  Not found
91104}
92105
93- //  Recursively creates an array version of the given vector type.
94- static  Type *replaceVectorWithArray (Type *T, LLVMContext &Ctx) {
95-   if  (auto  *VecTy = dyn_cast<VectorType>(T))
96-     return  ArrayType::get (VecTy->getElementType (),
97-                           dyn_cast<FixedVectorType>(VecTy)->getNumElements ());
98-   if  (auto  *ArrayTy = dyn_cast<ArrayType>(T)) {
99-     Type *NewElementType =
100-         replaceVectorWithArray (ArrayTy->getElementType (), Ctx);
101-     return  ArrayType::get (NewElementType, ArrayTy->getNumElements ());
102-   }
103-   //  If it's not a vector or array, return the original type.
104-   return  T;
105- }
106- 
107106static  bool  isArrayOfVectors (Type *T) {
108107  if  (ArrayType *ArrType = dyn_cast<ArrayType>(T))
109108    return  isa<VectorType>(ArrType->getElementType ());
@@ -116,8 +115,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
116115
117116  ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType ());
118117  IRBuilder<> Builder (&AI);
119-   LLVMContext &Ctx = AI.getContext ();
120-   Type *NewType = replaceVectorWithArray (ArrType, Ctx);
118+   Type *NewType = equivalentArrayTypeFromVector (ArrType);
121119  AllocaInst *ArrAlloca =
122120      Builder.CreateAlloca (NewType, nullptr , AI.getName () + " .scalarize" 
123121  ArrAlloca->setAlignment (AI.getAlign ());
@@ -173,6 +171,38 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
173171  return  false ;
174172}
175173
174+ bool  DataScalarizerVisitor::visitExtractElementInst (ExtractElementInst &EEI) {
175+   //  If the index is a constant then we don't need to scalarize it
176+   Value *Index = EEI.getIndexOperand ();
177+   Type *IndexTy = Index->getType ();
178+   if  (isa<ConstantInt>(Index))
179+     return  false ;
180+ 
181+   IRBuilder<> Builder (&EEI);
182+   VectorType *VecTy = EEI.getVectorOperandType ();
183+   assert (VecTy->getElementCount ().isFixed () &&
184+          " Vector operand of ExtractElement must have a fixed size" 
185+   
186+   Type *ArrTy = equivalentArrayTypeFromVector (VecTy);
187+   Value *ArrAlloca = Builder.CreateAlloca (ArrTy);
188+ 
189+   for  (unsigned  I = 0 ; I < ArrTy->getArrayNumElements (); ++I) {
190+     Value *EE = Builder.CreateExtractElement (EEI.getVectorOperand (), I);
191+     Value *GEP = Builder.CreateInBoundsGEP (
192+         ArrTy, ArrAlloca,
193+         {ConstantInt::get (IndexTy, 0 ), ConstantInt::get (IndexTy, I)});
194+     Builder.CreateStore (EE, GEP);
195+   }
196+ 
197+   Value *GEP = Builder.CreateInBoundsGEP (ArrTy, ArrAlloca,
198+                                          {ConstantInt::get (IndexTy, 0 ), Index});
199+   Value *Load = Builder.CreateLoad (ArrTy->getArrayElementType (), GEP);
200+ 
201+   EEI.replaceAllUsesWith (Load);
202+   EEI.eraseFromParent ();
203+   return  true ;
204+ }
205+ 
176206bool  DataScalarizerVisitor::visitGetElementPtrInst (GetElementPtrInst &GEPI) {
177207
178208  unsigned  NumOperands = GEPI.getNumOperands ();
@@ -257,7 +287,7 @@ static bool findAndReplaceVectors(Module &M) {
257287  for  (GlobalVariable &G : M.globals ()) {
258288    Type *OrigType = G.getValueType ();
259289
260-     Type *NewType = replaceVectorWithArray (OrigType, Ctx );
290+     Type *NewType = equivalentArrayTypeFromVector (OrigType);
261291    if  (OrigType != NewType) {
262292      //  Create a new global variable with the updated type
263293      //  Note: Initializer is set via transformInitializer
0 commit comments