Skip to content

[DirectX] Scalarize extractelement and insertelement with dynamic indices #141676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 20, 2025
150 changes: 129 additions & 21 deletions llvm/lib/Target/DirectX/DXILDataScalarization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ static const int MaxVecSize = 4;

using namespace llvm;

// Recursively creates an array-like version of a given vector type.
static Type *equivalentArrayTypeFromVector(Type *T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine I would have just made a function declaration at the top so the implementation could live anywhere.

if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
Type *NewElementType =
equivalentArrayTypeFromVector(ArrayTy->getElementType());
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
}
// If it's not a vector or array, return the original type.
return T;
}

class DXILDataScalarizationLegacy : public ModulePass {

public:
Expand Down Expand Up @@ -54,8 +68,8 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
bool visitCastInst(CastInst &CI) { return false; }
bool visitBitCastInst(BitCastInst &BCI) { return false; }
bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
bool visitInsertElementInst(InsertElementInst &IEI);
bool visitExtractElementInst(ExtractElementInst &EEI);
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
bool visitPHINode(PHINode &PHI) { return false; }
bool visitLoadInst(LoadInst &LI);
Expand All @@ -65,6 +79,16 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
friend bool findAndReplaceVectors(llvm::Module &M);

private:
typedef std::pair<AllocaInst *, SmallVector<Value *, 4>> AllocaAndGEPs;
typedef SmallDenseMap<Value *, AllocaAndGEPs>
VectorToArrayMap; // A map from a vector-typed Value to its corresponding
// AllocaInst and GEPs to each element of an array
VectorToArrayMap VectorAllocaMap;
AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
const Twine &Name);
bool replaceDynamicInsertElementInst(InsertElementInst &IEI);
bool replaceDynamicExtractElementInst(ExtractElementInst &EEI);

GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
};
Expand All @@ -76,6 +100,7 @@ bool DataScalarizerVisitor::visit(Function &F) {
for (Instruction &I : make_early_inc_range(*BB))
MadeChange |= InstVisitor::visit(I);
}
VectorAllocaMap.clear();
return MadeChange;
}

Expand All @@ -90,20 +115,6 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
return nullptr; // Not found
}

// Recursively creates an array version of the given vector type.
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
if (auto *VecTy = dyn_cast<VectorType>(T))
return ArrayType::get(VecTy->getElementType(),
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
Type *NewElementType =
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
}
// If it's not a vector or array, return the original type.
return T;
}

static bool isArrayOfVectors(Type *T) {
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
return isa<VectorType>(ArrType->getElementType());
Expand All @@ -116,8 +127,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {

ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
IRBuilder<> Builder(&AI);
LLVMContext &Ctx = AI.getContext();
Type *NewType = replaceVectorWithArray(ArrType, Ctx);
Type *NewType = equivalentArrayTypeFromVector(ArrType);
AllocaInst *ArrAlloca =
Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
ArrAlloca->setAlignment(AI.getAlign());
Expand Down Expand Up @@ -173,6 +183,104 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
return false;
}

DataScalarizerVisitor::AllocaAndGEPs
DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
const Twine &Name = "") {
// If there is already an alloca for this vector, return it
auto VA = VectorAllocaMap.find(Vec);
if (VA != VectorAllocaMap.end())
return VA->second;

auto InsertPoint = Builder.GetInsertPoint();
Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent());

Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());
AllocaInst *ArrAlloca =
Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca");
const uint64_t ArrNumElems = ArrTy->getArrayNumElements();

SmallVector<Value *, 4> GEPs(ArrNumElems);
for (unsigned I = 0; I < ArrNumElems; ++I) {
Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");
GEPs[I] = Builder.CreateInBoundsGEP(
ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)},
Name + ".index");
Builder.CreateStore(EE, GEPs[I]);
}

VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}});
Builder.SetInsertPoint(InsertPoint);
return {ArrAlloca, GEPs};
}

bool DataScalarizerVisitor::replaceDynamicInsertElementInst(
InsertElementInst &IEI) {
IRBuilder<> Builder(&IEI);

Value *Vec = IEI.getOperand(0);
Value *Val = IEI.getOperand(1);
Value *Index = IEI.getOperand(2);

AllocaAndGEPs ArrAllocaAndGEPs =
createArrayFromVector(Builder, Vec, IEI.getName());
AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
SmallVector<Value *, 4> &ArrGEPs = ArrAllocaAndGEPs.second;

Type *ArrTy = ArrAlloca->getAllocatedType();
Value *GEPForStore =
Builder.CreateInBoundsGEP(ArrTy, ArrAlloca, {Builder.getInt32(0), Index},
IEI.getName() + ".dynindex");
Builder.CreateStore(Val, GEPForStore);

Value *NewIEI = PoisonValue::get(Vec->getType());
for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {
Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), ArrGEPs[I],
IEI.getName() + ".load");
NewIEI = Builder.CreateInsertElement(NewIEI, Load, Builder.getInt32(I),
IEI.getName() + ".insert");
}

IEI.replaceAllUsesWith(NewIEI);
IEI.eraseFromParent();
return true;
}

bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
// If the index is a constant then we don't need to scalarize it
Value *Index = IEI.getOperand(2);
if (isa<ConstantInt>(Index))
return false;
return replaceDynamicInsertElementInst(IEI);
}

bool DataScalarizerVisitor::replaceDynamicExtractElementInst(
ExtractElementInst &EEI) {
IRBuilder<> Builder(&EEI);

AllocaAndGEPs ArrAllocaAndGEPs =
createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName());
AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;

Type *ArrTy = ArrAlloca->getAllocatedType();
Value *GEP = Builder.CreateInBoundsGEP(
ArrTy, ArrAlloca, {Builder.getInt32(0), EEI.getIndexOperand()},
EEI.getName() + ".index");
Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), GEP,
EEI.getName() + ".load");

EEI.replaceAllUsesWith(Load);
EEI.eraseFromParent();
return true;
}

bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
// If the index is a constant then we don't need to scalarize it
Value *Index = EEI.getIndexOperand();
if (isa<ConstantInt>(Index))
return false;
return replaceDynamicExtractElementInst(EEI);
}

bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {

unsigned NumOperands = GEPI.getNumOperands();
Expand All @@ -197,8 +305,8 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
return true;
}

Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
LLVMContext &Ctx) {
static Constant *transformInitializer(Constant *Init, Type *OrigType,
Type *NewType, LLVMContext &Ctx) {
// Handle ConstantAggregateZero (zero-initialized constants)
if (isa<ConstantAggregateZero>(Init)) {
return ConstantAggregateZero::get(NewType);
Expand Down Expand Up @@ -257,7 +365,7 @@ static bool findAndReplaceVectors(Module &M) {
for (GlobalVariable &G : M.globals()) {
Type *OrigType = G.getValueType();

Type *NewType = replaceVectorWithArray(OrigType, Ctx);
Type *NewType = equivalentArrayTypeFromVector(OrigType);
if (OrigType != NewType) {
// Create a new global variable with the updated type
// Note: Initializer is set via transformInitializer
Expand Down
127 changes: 127 additions & 0 deletions llvm/test/CodeGen/DirectX/scalarize-dynamic-vector-index.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

; Allocas should be placed in the entry block.
; Allocas should also be reused across multiple insertelement and extractelement instructions for the same vector
define void @alloca_placement_and_reuse(<3 x i32> %v1, <3 x i32> %v2, i32 %a, i32 %i, i32 %j) {
; CHECK-LABEL: define void @alloca_placement_and_reuse(
; CHECK-SAME: <3 x i32> [[V1:%.*]], <3 x i32> [[V2:%.*]], i32 [[A:%.*]], i32 [[I:%.*]], i32 [[J:%.*]]) {
; CHECK-NEXT: [[AL:%.*]] = alloca [3 x i32], align 4
; CHECK-NEXT: [[EE1_ALLOCA:%.*]] = alloca [3 x i32], align 4
; CHECK-NEXT: [[EE2_ALLOCA:%.*]] = alloca [3 x i32], align 4
; CHECK-NEXT: [[EE2_EXTRACT:%.*]] = extractelement <3 x i32> [[V2]], i64 0
; CHECK-NEXT: [[EE2_INDEX:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE2_ALLOCA]], i32 0, i32 0
; CHECK-NEXT: store i32 [[EE2_EXTRACT]], ptr [[EE2_INDEX]], align 4
; CHECK-NEXT: [[EE2_EXTRACT10:%.*]] = extractelement <3 x i32> [[V2]], i64 1
; CHECK-NEXT: [[EE2_INDEX11:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE2_ALLOCA]], i32 0, i32 1
; CHECK-NEXT: store i32 [[EE2_EXTRACT10]], ptr [[EE2_INDEX11]], align 4
; CHECK-NEXT: [[EE2_EXTRACT12:%.*]] = extractelement <3 x i32> [[V2]], i64 2
; CHECK-NEXT: [[EE2_INDEX13:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE2_ALLOCA]], i32 0, i32 2
; CHECK-NEXT: store i32 [[EE2_EXTRACT12]], ptr [[EE2_INDEX13]], align 4
; CHECK-NEXT: [[EE1_EXTRACT:%.*]] = extractelement <3 x i32> [[V1]], i64 0
; CHECK-NEXT: [[EE1_INDEX:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE1_ALLOCA]], i32 0, i32 0
; CHECK-NEXT: store i32 [[EE1_EXTRACT]], ptr [[EE1_INDEX]], align 4
; CHECK-NEXT: [[EE1_EXTRACT1:%.*]] = extractelement <3 x i32> [[V1]], i64 1
; CHECK-NEXT: [[EE1_INDEX2:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE1_ALLOCA]], i32 0, i32 1
; CHECK-NEXT: store i32 [[EE1_EXTRACT1]], ptr [[EE1_INDEX2]], align 4
; CHECK-NEXT: [[EE1_EXTRACT3:%.*]] = extractelement <3 x i32> [[V1]], i64 2
; CHECK-NEXT: [[EE1_INDEX4:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE1_ALLOCA]], i32 0, i32 2
; CHECK-NEXT: store i32 [[EE1_EXTRACT3]], ptr [[EE1_INDEX4]], align 4
; CHECK-NEXT: br label %[[BODY:.*]]
; CHECK: [[BODY]]:
; CHECK-NEXT: [[EE1_INDEX5:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE1_ALLOCA]], i32 0, i32 [[I]]
; CHECK-NEXT: [[EE1_LOAD:%.*]] = load i32, ptr [[EE1_INDEX5]], align 4
; CHECK-NEXT: [[IE1_DYNINDEX:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE1_ALLOCA]], i32 0, i32 [[I]]
; CHECK-NEXT: store i32 [[A]], ptr [[IE1_DYNINDEX]], align 4
; CHECK-NEXT: [[IE1_LOAD:%.*]] = load i32, ptr [[EE1_INDEX]], align 4
; CHECK-NEXT: [[IE1_INSERT:%.*]] = insertelement <3 x i32> poison, i32 [[IE1_LOAD]], i32 0
; CHECK-NEXT: [[IE1_LOAD6:%.*]] = load i32, ptr [[EE1_INDEX2]], align 4
; CHECK-NEXT: [[IE1_INSERT7:%.*]] = insertelement <3 x i32> [[IE1_INSERT]], i32 [[IE1_LOAD6]], i32 1
; CHECK-NEXT: [[IE1_LOAD8:%.*]] = load i32, ptr [[EE1_INDEX4]], align 4
; CHECK-NEXT: [[IE1_INSERT9:%.*]] = insertelement <3 x i32> [[IE1_INSERT7]], i32 [[IE1_LOAD8]], i32 2
; CHECK-NEXT: [[EE2_INDEX14:%.*]] = getelementptr inbounds [3 x i32], ptr [[EE2_ALLOCA]], i32 0, i32 [[J]]
; CHECK-NEXT: [[EE2_LOAD:%.*]] = load i32, ptr [[EE2_INDEX14]], align 4
; CHECK-NEXT: ret void
;
%al = alloca [3 x i32], align 4
br label %body
body:
%ee1 = extractelement <3 x i32> %v1, i32 %i
%ie1 = insertelement <3 x i32> %v1, i32 %a, i32 %i
%ee2 = extractelement <3 x i32> %v2, i32 %j
ret void
}

define float @extract_float_vec_dynamic(<4 x float> %v, i32 %i) {
; CHECK-LABEL: define float @extract_float_vec_dynamic(
; CHECK-SAME: <4 x float> [[V:%.*]], i32 [[I:%.*]]) {
; CHECK-NEXT: [[EE_ALLOCA:%.*]] = alloca [4 x float], align 4
; CHECK-NEXT: [[EE_EXTRACT:%.*]] = extractelement <4 x float> [[V]], i64 0
; CHECK-NEXT: [[EE_INDEX:%.*]] = getelementptr inbounds [4 x float], ptr [[EE_ALLOCA]], i32 0, i32 0
; CHECK-NEXT: store float [[EE_EXTRACT]], ptr [[EE_INDEX]], align 4
; CHECK-NEXT: [[EE_EXTRACT1:%.*]] = extractelement <4 x float> [[V]], i64 1
; CHECK-NEXT: [[EE_INDEX2:%.*]] = getelementptr inbounds [4 x float], ptr [[EE_ALLOCA]], i32 0, i32 1
; CHECK-NEXT: store float [[EE_EXTRACT1]], ptr [[EE_INDEX2]], align 4
; CHECK-NEXT: [[EE_EXTRACT3:%.*]] = extractelement <4 x float> [[V]], i64 2
; CHECK-NEXT: [[EE_INDEX4:%.*]] = getelementptr inbounds [4 x float], ptr [[EE_ALLOCA]], i32 0, i32 2
; CHECK-NEXT: store float [[EE_EXTRACT3]], ptr [[EE_INDEX4]], align 4
; CHECK-NEXT: [[EE_EXTRACT5:%.*]] = extractelement <4 x float> [[V]], i64 3
; CHECK-NEXT: [[EE_INDEX6:%.*]] = getelementptr inbounds [4 x float], ptr [[EE_ALLOCA]], i32 0, i32 3
; CHECK-NEXT: store float [[EE_EXTRACT5]], ptr [[EE_INDEX6]], align 4
; CHECK-NEXT: [[EE_INDEX7:%.*]] = getelementptr inbounds [4 x float], ptr [[EE_ALLOCA]], i32 0, i32 [[I]]
; CHECK-NEXT: [[EE_LOAD:%.*]] = load float, ptr [[EE_INDEX7]], align 4
; CHECK-NEXT: ret float [[EE_LOAD]]
;
%ee = extractelement <4 x float> %v, i32 %i
ret float %ee
}

define <3 x i32> @insert_i32_vec_dynamic(<3 x i32> %v, i32 %a, i32 %i) {
; CHECK-LABEL: define <3 x i32> @insert_i32_vec_dynamic(
; CHECK-SAME: <3 x i32> [[V:%.*]], i32 [[A:%.*]], i32 [[I:%.*]]) {
; CHECK-NEXT: [[IE_ALLOCA:%.*]] = alloca [3 x i32], align 4
; CHECK-NEXT: [[IE_EXTRACT:%.*]] = extractelement <3 x i32> [[V]], i64 0
; CHECK-NEXT: [[IE_INDEX:%.*]] = getelementptr inbounds [3 x i32], ptr [[IE_ALLOCA]], i32 0, i32 0
; CHECK-NEXT: store i32 [[IE_EXTRACT]], ptr [[IE_INDEX]], align 4
; CHECK-NEXT: [[IE_EXTRACT1:%.*]] = extractelement <3 x i32> [[V]], i64 1
; CHECK-NEXT: [[IE_INDEX2:%.*]] = getelementptr inbounds [3 x i32], ptr [[IE_ALLOCA]], i32 0, i32 1
; CHECK-NEXT: store i32 [[IE_EXTRACT1]], ptr [[IE_INDEX2]], align 4
; CHECK-NEXT: [[IE_EXTRACT3:%.*]] = extractelement <3 x i32> [[V]], i64 2
; CHECK-NEXT: [[IE_INDEX4:%.*]] = getelementptr inbounds [3 x i32], ptr [[IE_ALLOCA]], i32 0, i32 2
; CHECK-NEXT: store i32 [[IE_EXTRACT3]], ptr [[IE_INDEX4]], align 4
; CHECK-NEXT: [[IE_DYNINDEX:%.*]] = getelementptr inbounds [3 x i32], ptr [[IE_ALLOCA]], i32 0, i32 [[I]]
; CHECK-NEXT: store i32 [[A]], ptr [[IE_DYNINDEX]], align 4
; CHECK-NEXT: [[IE_LOAD:%.*]] = load i32, ptr [[IE_INDEX]], align 4
; CHECK-NEXT: [[IE_INSERT:%.*]] = insertelement <3 x i32> poison, i32 [[IE_LOAD]], i32 0
; CHECK-NEXT: [[IE_LOAD5:%.*]] = load i32, ptr [[IE_INDEX2]], align 4
; CHECK-NEXT: [[IE_INSERT6:%.*]] = insertelement <3 x i32> [[IE_INSERT]], i32 [[IE_LOAD5]], i32 1
; CHECK-NEXT: [[IE_LOAD7:%.*]] = load i32, ptr [[IE_INDEX4]], align 4
; CHECK-NEXT: [[IE_INSERT8:%.*]] = insertelement <3 x i32> [[IE_INSERT6]], i32 [[IE_LOAD7]], i32 2
; CHECK-NEXT: ret <3 x i32> [[IE_INSERT8]]
;
%ie = insertelement <3 x i32> %v, i32 %a, i32 %i
ret <3 x i32> %ie
}

; An extractelement with a constant index should not be converted to array form
define i16 @extract_i16_vec_constant(<4 x i16> %v) {
; CHECK-LABEL: define i16 @extract_i16_vec_constant(
; CHECK-SAME: <4 x i16> [[V:%.*]]) {
; CHECK-NEXT: [[EE:%.*]] = extractelement <4 x i16> [[V]], i32 1
; CHECK-NEXT: ret i16 [[EE]]
;
%ee = extractelement <4 x i16> %v, i32 1
ret i16 %ee
}

; An insertelement with a constant index should not be converted to array form
define <2 x half> @insert_half_vec_constant(<2 x half> %v, half %a) {
; CHECK-LABEL: define <2 x half> @insert_half_vec_constant(
; CHECK-SAME: <2 x half> [[V:%.*]], half [[A:%.*]]) {
; CHECK-NEXT: [[IE:%.*]] = insertelement <2 x half> [[V]], half [[A]], i32 1
; CHECK-NEXT: ret <2 x half> [[IE]]
;
%ie = insertelement <2 x half> %v, half %a, i32 1
ret <2 x half> %ie
}

Loading