Skip to content

Commit d3f1a51

Browse files
committed
Scalarize extractelement with dynamic index
1 parent 317f3bd commit d3f1a51

File tree

2 files changed

+86
-18
lines changed

2 files changed

+86
-18
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,19 @@ static const int MaxVecSize = 4;
2727

2828
using namespace llvm;
2929

30+
// Recursively creates an array-like version of a given vector type.
31+
static Type *equivalentArrayTypeFromVector(Type *T) {
32+
if (auto *VecTy = dyn_cast<VectorType>(T))
33+
return ArrayType::get(VecTy->getElementType(),
34+
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
35+
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
36+
Type *NewElementType = equivalentArrayTypeFromVector(ArrayTy->getElementType());
37+
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
38+
}
39+
// If it's not a vector or array, return the original type.
40+
return T;
41+
}
42+
3043
class DXILDataScalarizationLegacy : public ModulePass {
3144

3245
public:
@@ -55,7 +68,7 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
5568
bool visitCastInst(CastInst &CI) { return false; }
5669
bool visitBitCastInst(BitCastInst &BCI) { return false; }
5770
bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
58-
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
71+
bool visitExtractElementInst(ExtractElementInst &EEI);
5972
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
6073
bool visitPHINode(PHINode &PHI) { return false; }
6174
bool visitLoadInst(LoadInst &LI);
@@ -90,20 +103,6 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
90103
return nullptr; // Not found
91104
}
92105

93-
// Recursively creates an array version of the given vector type.
94-
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
95-
if (auto *VecTy = dyn_cast<VectorType>(T))
96-
return ArrayType::get(VecTy->getElementType(),
97-
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
98-
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
99-
Type *NewElementType =
100-
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
101-
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
102-
}
103-
// If it's not a vector or array, return the original type.
104-
return T;
105-
}
106-
107106
static bool isArrayOfVectors(Type *T) {
108107
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
109108
return isa<VectorType>(ArrType->getElementType());
@@ -116,8 +115,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
116115

117116
ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
118117
IRBuilder<> Builder(&AI);
119-
LLVMContext &Ctx = AI.getContext();
120-
Type *NewType = replaceVectorWithArray(ArrType, Ctx);
118+
Type *NewType = equivalentArrayTypeFromVector(ArrType);
121119
AllocaInst *ArrAlloca =
122120
Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
123121
ArrAlloca->setAlignment(AI.getAlign());
@@ -173,6 +171,38 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
173171
return false;
174172
}
175173

174+
bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
175+
// If the index is a constant then we don't need to scalarize it
176+
Value *Index = EEI.getIndexOperand();
177+
Type *IndexTy = Index->getType();
178+
if (isa<ConstantInt>(Index))
179+
return false;
180+
181+
IRBuilder<> Builder(&EEI);
182+
VectorType *VecTy = EEI.getVectorOperandType();
183+
assert(VecTy->getElementCount().isFixed() &&
184+
"Vector operand of ExtractElement must have a fixed size");
185+
186+
Type *ArrTy = equivalentArrayTypeFromVector(VecTy);
187+
Value *ArrAlloca = Builder.CreateAlloca(ArrTy);
188+
189+
for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {
190+
Value *EE = Builder.CreateExtractElement(EEI.getVectorOperand(), I);
191+
Value *GEP = Builder.CreateInBoundsGEP(
192+
ArrTy, ArrAlloca,
193+
{ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, I)});
194+
Builder.CreateStore(EE, GEP);
195+
}
196+
197+
Value *GEP = Builder.CreateInBoundsGEP(ArrTy, ArrAlloca,
198+
{ConstantInt::get(IndexTy, 0), Index});
199+
Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), GEP);
200+
201+
EEI.replaceAllUsesWith(Load);
202+
EEI.eraseFromParent();
203+
return true;
204+
}
205+
176206
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
177207

178208
unsigned NumOperands = GEPI.getNumOperands();
@@ -257,7 +287,7 @@ static bool findAndReplaceVectors(Module &M) {
257287
for (GlobalVariable &G : M.globals()) {
258288
Type *OrigType = G.getValueType();
259289

260-
Type *NewType = replaceVectorWithArray(OrigType, Ctx);
290+
Type *NewType = equivalentArrayTypeFromVector(OrigType);
261291
if (OrigType != NewType) {
262292
// Create a new global variable with the updated type
263293
// Note: Initializer is set via transformInitializer
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
define float @extract_float_vec_dynamic(<4 x float> %0, i32 %1) {
5+
; CHECK-LABEL: define float @extract_float_vec_dynamic(
6+
; CHECK-SAME: <4 x float> [[TMP0:%.*]], i32 [[TMP1:%.*]]) {
7+
; CHECK-NEXT: [[TMP3:%.*]] = alloca [4 x float], align 4
8+
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[TMP0]], i64 0
9+
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 0
10+
; CHECK-NEXT: store float [[TMP4]], ptr [[TMP5]], align 4
11+
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x float> [[TMP0]], i64 1
12+
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 1
13+
; CHECK-NEXT: store float [[TMP6]], ptr [[TMP7]], align 4
14+
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x float> [[TMP0]], i64 2
15+
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 2
16+
; CHECK-NEXT: store float [[TMP8]], ptr [[TMP9]], align 4
17+
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <4 x float> [[TMP0]], i64 3
18+
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 3
19+
; CHECK-NEXT: store float [[TMP10]], ptr [[TMP11]], align 4
20+
; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP3]], i32 0, i32 [[TMP1]]
21+
; CHECK-NEXT: [[TMP13:%.*]] = load float, ptr [[TMP12]], align 4
22+
; CHECK-NEXT: ret float [[TMP13]]
23+
;
24+
%e = extractelement <4 x float> %0, i32 %1
25+
ret float %e
26+
}
27+
28+
; An extractelement with a constant index should not be converted to array form
29+
define i16 @extract_i16_vec_constant(<4 x i16> %0) {
30+
; CHECK-LABEL: define i16 @extract_i16_vec_constant(
31+
; CHECK-SAME: <4 x i16> [[TMP0:%.*]]) {
32+
; CHECK-NEXT: [[E:%.*]] = extractelement <4 x i16> [[TMP0]], i32 1
33+
; CHECK-NEXT: ret i16 [[E]]
34+
;
35+
%e = extractelement <4 x i16> %0, i32 1
36+
ret i16 %e
37+
}
38+

0 commit comments

Comments
 (0)