Skip to content
117 changes: 98 additions & 19 deletions llvm/lib/Target/DirectX/DXILDataScalarization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ static const int MaxVecSize = 4;

using namespace llvm;

// Recursively creates an array-like version of a given vector type.
static Type *equivalentArrayTypeFromVector(Type *T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine I would have just made a function declaration at the top so the implementation could live anywhere.

if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
Type *NewElementType =
equivalentArrayTypeFromVector(ArrayTy->getElementType());
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
}
// If it's not a vector or array, return the original type.
return T;
}

class DXILDataScalarizationLegacy : public ModulePass {

public:
Expand Down Expand Up @@ -54,8 +68,8 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
bool visitCastInst(CastInst &CI) { return false; }
bool visitBitCastInst(BitCastInst &BCI) { return false; }
bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
bool visitInsertElementInst(InsertElementInst &IEI);
bool visitExtractElementInst(ExtractElementInst &EEI);
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
bool visitPHINode(PHINode &PHI) { return false; }
bool visitLoadInst(LoadInst &LI);
Expand Down Expand Up @@ -90,20 +104,6 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
return nullptr; // Not found
}

// Recursively creates an array version of the given vector type.
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
Type *NewElementType =
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
}
// If it's not a vector or array, return the original type.
return T;
}

static bool isArrayOfVectors(Type *T) {
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
return isa<VectorType>(ArrType->getElementType());
Expand All @@ -116,8 +116,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {

ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
IRBuilder<> Builder(&AI);
LLVMContext &Ctx = AI.getContext();
Type *NewType = replaceVectorWithArray(ArrType, Ctx);
Type *NewType = equivalentArrayTypeFromVector(ArrType);
AllocaInst *ArrAlloca =
Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
ArrAlloca->setAlignment(AI.getAlign());
Expand Down Expand Up @@ -173,6 +172,86 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
return false;
}

static bool replaceDynamicInsertElementInst(InsertElementInst &IEI) {
IRBuilder<> Builder(&IEI);

Value *Vec = IEI.getOperand(0);
Value *Val = IEI.getOperand(1);
Value *Index = IEI.getOperand(2);
Type *IndexTy = Index->getType();

Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());
Value *ArrAlloca = Builder.CreateAlloca(ArrTy);
const uint64_t ArrNumElems = ArrTy->getArrayNumElements();

SmallVector<Value *, 4> GEPs(ArrNumElems);
for (unsigned I = 0; I < ArrNumElems; ++I) {
Value *EE = Builder.CreateExtractElement(Vec, I);
Value *GEP = Builder.CreateInBoundsGEP(
ArrTy, ArrAlloca,
{ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, I)});
Builder.CreateStore(EE, GEP);
GEPs[I] = GEP;
}

Value *GEPForStore = Builder.CreateInBoundsGEP(
ArrTy, ArrAlloca, {ConstantInt::get(IndexTy, 0), Index});
Builder.CreateStore(Val, GEPForStore);

Value *NewIEI = PoisonValue::get(Vec->getType());
for (unsigned I = 0; I < ArrNumElems; ++I) {
Value *GEP = GEPs[I];
Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), GEP);
NewIEI =
Builder.CreateInsertElement(NewIEI, Load, ConstantInt::get(IndexTy, I));
}

IEI.replaceAllUsesWith(NewIEI);
IEI.eraseFromParent();
return true;
}

bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
// If the index is a constant then we don't need to scalarize it
Value *Index = IEI.getOperand(2);
if (isa<ConstantInt>(Index))
return false;
return replaceDynamicInsertElementInst(IEI);
}

static bool replaceDynamicExtractElementInst(ExtractElementInst &EEI) {
IRBuilder<> Builder(&EEI);

Value *Index = EEI.getIndexOperand();
Type *IndexTy = Index->getType();

Type *ArrTy = equivalentArrayTypeFromVector(EEI.getVectorOperandType());
Value *ArrAlloca = Builder.CreateAlloca(ArrTy);
for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {
Value *EE = Builder.CreateExtractElement(EEI.getVectorOperand(), I);
Value *GEP = Builder.CreateInBoundsGEP(
ArrTy, ArrAlloca,
{ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, I)});
Builder.CreateStore(EE, GEP);
}

Value *GEP = Builder.CreateInBoundsGEP(ArrTy, ArrAlloca,
{ConstantInt::get(IndexTy, 0), Index});
Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), GEP);

EEI.replaceAllUsesWith(Load);
EEI.eraseFromParent();
return true;
}

bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
// If the index is a constant then we don't need to scalarize it
Value *Index = EEI.getIndexOperand();
if (isa<ConstantInt>(Index))
return false;
return replaceDynamicExtractElementInst(EEI);
}

bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {

unsigned NumOperands = GEPI.getNumOperands();
Expand Down Expand Up @@ -257,7 +336,7 @@ static bool findAndReplaceVectors(Module &M) {
for (GlobalVariable &G : M.globals()) {
Type *OrigType = G.getValueType();

Type *NewType = replaceVectorWithArray(OrigType, Ctx);
Type *NewType = equivalentArrayTypeFromVector(OrigType);
if (OrigType != NewType) {
// Create a new global variable with the updated type
// Note: Initializer is set via transformInitializer
Expand Down
76 changes: 76 additions & 0 deletions llvm/test/CodeGen/DirectX/scalarize-dynamic-vector-index.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

define float @extract_float_vec_dynamic(<4 x float> %v, i32 %i) {
; CHECK-LABEL: define float @extract_float_vec_dynamic(
; CHECK-SAME: <4 x float> [[V:%.*]], i32 [[I:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [4 x float], align 4
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[V]], i64 0
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP1]], i32 0, i32 0
; CHECK-NEXT: store float [[TMP2]], ptr [[TMP3]], align 4
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[V]], i64 1
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP1]], i32 0, i32 1
; CHECK-NEXT: store float [[TMP4]], ptr [[TMP5]], align 4
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x float> [[V]], i64 2
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP1]], i32 0, i32 2
; CHECK-NEXT: store float [[TMP6]], ptr [[TMP7]], align 4
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x float> [[V]], i64 3
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP1]], i32 0, i32 3
; CHECK-NEXT: store float [[TMP8]], ptr [[TMP9]], align 4
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [4 x float], ptr [[TMP1]], i32 0, i32 [[I]]
; CHECK-NEXT: [[TMP11:%.*]] = load float, ptr [[TMP10]], align 4
; CHECK-NEXT: ret float [[TMP11]]
;
%ee = extractelement <4 x float> %v, i32 %i
ret float %ee
}

define <3 x i32> @insert_i32_vec_dynamic(<3 x i32> %v, i32 %a, i32 %i) {
; CHECK-LABEL: define <3 x i32> @insert_i32_vec_dynamic(
; CHECK-SAME: <3 x i32> [[V:%.*]], i32 [[A:%.*]], i32 [[I:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = alloca [3 x i32], align 4
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <3 x i32> [[V]], i64 0
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds [3 x i32], ptr [[TMP1]], i32 0, i32 0
; CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP3]], align 4
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <3 x i32> [[V]], i64 1
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [3 x i32], ptr [[TMP1]], i32 0, i32 1
; CHECK-NEXT: store i32 [[TMP4]], ptr [[TMP5]], align 4
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <3 x i32> [[V]], i64 2
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds [3 x i32], ptr [[TMP1]], i32 0, i32 2
; CHECK-NEXT: store i32 [[TMP6]], ptr [[TMP7]], align 4
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [3 x i32], ptr [[TMP1]], i32 0, i32 [[I]]
; CHECK-NEXT: store i32 [[A]], ptr [[TMP8]], align 4
; CHECK-NEXT: [[TMP9:%.*]] = load i32, ptr [[TMP3]], align 4
; CHECK-NEXT: [[TMP10:%.*]] = insertelement <3 x i32> poison, i32 [[TMP9]], i32 0
; CHECK-NEXT: [[TMP11:%.*]] = load i32, ptr [[TMP5]], align 4
; CHECK-NEXT: [[TMP12:%.*]] = insertelement <3 x i32> [[TMP10]], i32 [[TMP11]], i32 1
; CHECK-NEXT: [[TMP13:%.*]] = load i32, ptr [[TMP7]], align 4
; CHECK-NEXT: [[TMP14:%.*]] = insertelement <3 x i32> [[TMP12]], i32 [[TMP13]], i32 2
; CHECK-NEXT: ret <3 x i32> [[TMP14]]
;
%ie = insertelement <3 x i32> %v, i32 %a, i32 %i
ret <3 x i32> %ie
}

; An extractelement with a constant index should not be converted to array form
define i16 @extract_i16_vec_constant(<4 x i16> %v) {
; CHECK-LABEL: define i16 @extract_i16_vec_constant(
; CHECK-SAME: <4 x i16> [[V:%.*]]) {
; CHECK-NEXT: [[EE:%.*]] = extractelement <4 x i16> [[V]], i32 1
; CHECK-NEXT: ret i16 [[EE]]
;
%ee = extractelement <4 x i16> %v, i32 1
ret i16 %ee
}

; An insertelement with a constant index should not be converted to array form
define <2 x half> @insert_half_vec_constant(<2 x half> %v, half %a) {
; CHECK-LABEL: define <2 x half> @insert_half_vec_constant(
; CHECK-SAME: <2 x half> [[V:%.*]], half [[A:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x half> [[V]], half [[A]], i32 1
; CHECK-NEXT: ret <2 x half> [[TMP1]]
;
%ie = insertelement <2 x half> %v, half %a, i32 1
ret <2 x half> %ie
}

Loading