Skip to content

Commit 43e06e0

Browse files
committed
[DirectX] Scalarize Allocas as part of data scalarization
- DXILDataScalarization should not be limited to global data. - Add scalarization for allocas. - Add a ReversePostOrderTraversal of each function that iterates over its basic blocks and runs DataScalarizerVisitor on every instruction. - Fixes #140143.
1 parent 8e53e3b commit 43e06e0

File tree

3 files changed

+76
-35
lines changed

3 files changed

+76
-35
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "DirectX.h"
1111
#include "llvm/ADT/PostOrderIterator.h"
1212
#include "llvm/ADT/STLExtras.h"
13+
#include "llvm/IR/DerivedTypes.h"
1314
#include "llvm/IR/GlobalVariable.h"
1415
#include "llvm/IR/IRBuilder.h"
1516
#include "llvm/IR/InstVisitor.h"
@@ -40,9 +41,10 @@ static bool findAndReplaceVectors(Module &M);
4041
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
4142
public:
4243
DataScalarizerVisitor() : GlobalMap() {}
43-
bool visit(Instruction &I);
44+
bool visit(Function &F);
4445
// InstVisitor methods. They return true if the instruction was scalarized,
4546
// false if nothing changed.
47+
bool visitAllocaInst(AllocaInst &AI);
4648
bool visitInstruction(Instruction &I) { return false; }
4749
bool visitSelectInst(SelectInst &SI) { return false; }
4850
bool visitICmpInst(ICmpInst &ICI) { return false; }
@@ -65,11 +67,17 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
6567
private:
6668
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
6769
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
70+
static bool isArrayOfVectors(Type *T);
6871
};
6972

70-
bool DataScalarizerVisitor::visit(Instruction &I) {
71-
assert(!GlobalMap.empty());
72-
return InstVisitor::visit(I);
73+
// Runs the visitor over every instruction of F and reports whether any
// visit method changed the IR.
//
// Basic blocks are walked in reverse post-order. Instructions are iterated
// with make_early_inc_range so the instruction currently being visited may
// be erased by its visit method (visitAllocaInst erases the alloca it
// replaces) without invalidating the loop.
bool DataScalarizerVisitor::visit(Function &F) {
74+
bool MadeChange = false;
75+
ReversePostOrderTraversal<Function *> RPOT(&F);
76+
for (BasicBlock *BB : make_early_inc_range(RPOT)) {
77+
for (Instruction &I : make_early_inc_range(*BB))
78+
// Dispatch to the matching visitXXX method; it returns true on change.
MadeChange |= InstVisitor::visit(I);
79+
}
80+
return MadeChange;
7381
}
7482

7583
GlobalVariable *
@@ -83,6 +91,42 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
8391
return nullptr; // Not found
8492
}
8593

94+
// Recursively Creates and Array like version of the given vector like type.
95+
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
96+
if (auto *VecTy = dyn_cast<VectorType>(T))
97+
return ArrayType::get(VecTy->getElementType(),
98+
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
99+
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
100+
Type *NewElementType =
101+
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
102+
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
103+
}
104+
// If it's not a vector or array, return the original type.
105+
return T;
106+
}
107+
108+
// Returns true when T is an array type whose immediate element type is a
// vector type; returns false for every other type.
// NOTE(review): only the outermost array level is inspected, so a nested
// array such as [2 x [3 x <4 x i32>]] yields false here even though
// replaceVectorWithArray recurses through nested arrays - confirm whether
// that is intended.
bool DataScalarizerVisitor::isArrayOfVectors(Type *T) {
109+
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
110+
return isa<VectorType>(ArrType->getElementType());
111+
return false;
112+
}
113+
114+
// Replaces an alloca of an array of vectors with an alloca of the equivalent
// array of arrays (e.g. [2 x <4 x i32>] becomes [2 x [4 x i32]]).
//
// Returns false, leaving AI untouched, when the allocated type is not an
// array of vectors. Otherwise a new alloca named "<old name>.scalarize" is
// inserted before AI, all uses of AI are redirected to it, AI is erased, and
// true is returned. AI must not be used by the caller after a true return.
bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
115+
if (!isArrayOfVectors(AI.getAllocatedType()))
116+
return false;
117+
118+
ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
119+
IRBuilder<> Builder(&AI);
120+
LLVMContext &Ctx = AI.getContext();
121+
Type *NewType = replaceVectorWithArray(ArrType, Ctx);
122+
// Carry over the address space and array-size operand of the original
// alloca so the replacement reserves the same storage in the same place.
AllocaInst *ArrAlloca =
123+
Builder.CreateAlloca(NewType, AI.getAddressSpace(), AI.getArraySize(), AI.getName() + ".scalarize");
124+
ArrAlloca->setAlignment(AI.getAlign());
125+
AI.replaceAllUsesWith(ArrAlloca);
126+
AI.eraseFromParent();
127+
return true;
128+
}
129+
86130
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
87131
unsigned NumOperands = LI.getNumOperands();
88132
for (unsigned I = 0; I < NumOperands; ++I) {
@@ -154,20 +198,6 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
154198
return true;
155199
}
156200

157-
// Recursively Creates and Array like version of the given vector like type.
158-
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
159-
if (auto *VecTy = dyn_cast<VectorType>(T))
160-
return ArrayType::get(VecTy->getElementType(),
161-
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
162-
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
163-
Type *NewElementType =
164-
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
165-
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
166-
}
167-
// If it's not a vector or array, return the original type.
168-
return T;
169-
}
170-
171201
Constant *transformInitializer(Constant *Init, Type *OrigType, Type *NewType,
172202
LLVMContext &Ctx) {
173203
// Handle ConstantAggregateZero (zero-initialized constants)
@@ -253,20 +283,15 @@ static bool findAndReplaceVectors(Module &M) {
253283
// Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
254284
// type equality. Instead we will use the visitor pattern.
255285
Impl.GlobalMap[&G] = NewGlobal;
256-
for (User *U : make_early_inc_range(G.users())) {
257-
if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
258-
ConstantExpr *CE = cast<ConstantExpr>(U);
259-
for (User *UCE : make_early_inc_range(CE->users())) {
260-
if (Instruction *Inst = dyn_cast<Instruction>(UCE))
261-
Impl.visit(*Inst);
262-
}
263-
}
264-
if (Instruction *Inst = dyn_cast<Instruction>(U))
265-
Impl.visit(*Inst);
266-
}
267286
}
268287
}
269288

289+
for (auto &F : make_early_inc_range(M.functions())) {
290+
if (F.isDeclaration())
291+
continue;
292+
MadeChange |= Impl.visit(F);
293+
}
294+
270295
// Remove the old globals after the iteration
271296
for (auto &[Old, New] : Impl.GlobalMap) {
272297
Old->eraseFromParent();

llvm/test/CodeGen/DirectX/scalar-bug-117273.ll

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@
88
define internal void @main() #1 {
99
; CHECK-LABEL: define internal void @main() {
1010
; CHECK-NEXT: [[ENTRY:.*:]]
11-
; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16
12-
; CHECK-NEXT: [[DOTI1:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 1), align 4
13-
; CHECK-NEXT: [[DOTI2:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 2), align 8
14-
; CHECK-NEXT: [[DOTI01:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), align 16
15-
; CHECK-NEXT: [[DOTI12:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 1), align 4
16-
; CHECK-NEXT: [[DOTI23:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 2), align 8
11+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 1
12+
; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
13+
; CHECK-NEXT: [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1
14+
; CHECK-NEXT: [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4
15+
; CHECK-NEXT: [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2
16+
; CHECK-NEXT: [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8
17+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 2
18+
; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
19+
; CHECK-NEXT: [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
20+
; CHECK-NEXT: [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4
21+
; CHECK-NEXT: [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2
22+
; CHECK-NEXT: [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8
1723
; CHECK-NEXT: ret void
1824
;
1925
entry:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
2+
; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
3+
4+
; Both RUN lines also enable the shared CHECK prefix so the label below is matched.
; CHECK-LABEL: alloca_2d__vec_test
5+
define void @alloca_2d__vec_test() local_unnamed_addr #2 {
6+
; SCHECK: alloca [2 x [4 x i32]], align 16
7+
; FCHECK: alloca [8 x i32], align 16
8+
%1 = alloca [2 x <4 x i32>], align 16
9+
ret void
10+
}

0 commit comments

Comments
 (0)