diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index 1f2700ac55647..06708cec00cec 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -10,6 +10,7 @@ #include "DirectX.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" @@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M); class DataScalarizerVisitor : public InstVisitor { public: DataScalarizerVisitor() : GlobalMap() {} - bool visit(Instruction &I); + bool visit(Function &F); // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. + bool visitAllocaInst(AllocaInst &AI); bool visitInstruction(Instruction &I) { return false; } bool visitSelectInst(SelectInst &SI) { return false; } bool visitICmpInst(ICmpInst &ICI) { return false; } @@ -67,9 +69,14 @@ class DataScalarizerVisitor : public InstVisitor { DenseMap GlobalMap; }; -bool DataScalarizerVisitor::visit(Instruction &I) { - assert(!GlobalMap.empty()); - return InstVisitor::visit(I); +bool DataScalarizerVisitor::visit(Function &F) { + bool MadeChange = false; + ReversePostOrderTraversal RPOT(&F); + for (BasicBlock *BB : make_early_inc_range(RPOT)) { + for (Instruction &I : make_early_inc_range(*BB)) + MadeChange |= InstVisitor::visit(I); + } + return MadeChange; } GlobalVariable * @@ -83,6 +90,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) { return nullptr; // Not found } +// Recursively creates an array version of the given vector type. +static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { + if (auto *VecTy = dyn_cast(T)) + return ArrayType::get(VecTy->getElementType(), + dyn_cast(VecTy)->getNumElements()); + if (auto *ArrayTy = dyn_cast(T)) { + Type *NewElementType = + replaceVectorWithArray(ArrayTy->getElementType(), Ctx); + return ArrayType::get(NewElementType, ArrayTy->getNumElements()); + } + // If it's not a vector or array, return the original type. + return T; +} + +static bool isArrayOfVectors(Type *T) { + if (ArrayType *ArrType = dyn_cast(T)) + return isa(ArrType->getElementType()); + return false; +} + +bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) { + if (!isArrayOfVectors(AI.getAllocatedType())) + return false; + + ArrayType *ArrType = cast(AI.getAllocatedType()); + IRBuilder<> Builder(&AI); + LLVMContext &Ctx = AI.getContext(); + Type *NewType = replaceVectorWithArray(ArrType, Ctx); + AllocaInst *ArrAlloca = + Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize"); + ArrAlloca->setAlignment(AI.getAlign()); + AI.replaceAllUsesWith(ArrAlloca); + AI.eraseFromParent(); + return true; +} + bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) { unsigned NumOperands = LI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { @@ -154,20 +197,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { return true; } -// Recursively Creates and Array like version of the given vector like type. -static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) { - if (auto *VecTy = dyn_cast(T)) - return ArrayType::get(VecTy->getElementType(), - dyn_cast(VecTy)->getNumElements()); - if (auto *ArrayTy = dyn_cast(T)) { - Type *NewElementType = - replaceVectorWithArray(ArrayTy->getElementType(), Ctx); - return ArrayType::get(NewElementType, ArrayTy->getNumElements()); - } - // If it's not a vector or array, return the original type. - return T; -} - Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType, LLVMContext &Ctx) { // Handle ConstantAggregateZero (zero-initialized constants) @@ -253,20 +282,15 @@ static bool findAndReplaceVectors(Module &M) { // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes // type equality. Instead we will use the visitor pattern. Impl.GlobalMap[&G] = NewGlobal; - for (User *U : make_early_inc_range(G.users())) { - if (isa(U) && isa(U)) { - ConstantExpr *CE = cast(U); - for (User *UCE : make_early_inc_range(CE->users())) { - if (Instruction *Inst = dyn_cast(UCE)) - Impl.visit(*Inst); - } - } - if (Instruction *Inst = dyn_cast(U)) - Impl.visit(*Inst); - } } } + for (auto &F : make_early_inc_range(M.functions())) { + if (F.isDeclaration()) + continue; + MadeChange |= Impl.visit(F); + } + // Remove the old globals after the iteration for (auto &[Old, New] : Impl.GlobalMap) { Old->eraseFromParent(); diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll index 25dc2c36b4e1f..2676abec1d8ae 100644 --- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll +++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll @@ -8,12 +8,18 @@ define internal void @main() #1 { ; CHECK-LABEL: define internal void @main() { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16 -; CHECK-NEXT: [[DOTI1:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 1), align 4 -; CHECK-NEXT: [[DOTI2:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 2), align 8 -; CHECK-NEXT: [[DOTI01:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), align 16 -; CHECK-NEXT: [[DOTI12:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 1), align 4 -; CHECK-NEXT: [[DOTI23:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 2), align 8 +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 1 +; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16 +; CHECK-NEXT: [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1 +; CHECK-NEXT: [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4 +; CHECK-NEXT: [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2 +; CHECK-NEXT: [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 2 +; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16 +; CHECK-NEXT: [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1 +; CHECK-NEXT: [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4 +; CHECK-NEXT: [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2 +; CHECK-NEXT: [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8 ; CHECK-NEXT: ret void ; entry: diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll new file mode 100644 index 0000000000000..4829f3a31791f --- /dev/null +++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll @@ -0,0 +1,10 @@ +; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=SCHECK +; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefix=FCHECK + +; CHECK-LABEL: alloca_2d__vec_test +define void @alloca_2d__vec_test() local_unnamed_addr #2 { + ; SCHECK: alloca [2 x [4 x i32]], align 16 + ; FCHECK: alloca [8 x i32], align 16 + %1 = alloca [2 x <4 x i32>], align 16 + ret void +}