Skip to content

Commit ebfd7d0

Browse files
committed
[DirectX] Support ConstExpr GEPs
- Fixes #150050 - Address the issue of many nested geps - Check for ConstantExpr GEP if we see it check if it needs a global replacement with a new type. Create the new constExpr Gep and set the pointer operand to it. Finally cleanup and remove the old nested geps.
1 parent bbe912f commit ebfd7d0

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,11 +300,84 @@ bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
300300
return replaceDynamicExtractElementInst(EEI);
301301
}
302302

303+
static void buildConstExprGEPChain(GetElementPtrInst &GEPI, Value *CurrentPtr,
304+
SmallVector<ConstantExpr *, 4> &GEPChain,
305+
IRBuilder<> &Builder) {
306+
// Process the rest of the chain in reverse order (skipping the innermost)
307+
for (int I = GEPChain.size() - 2; I >= 0; I--) {
308+
ConstantExpr *CE = GEPChain[I];
309+
GetElementPtrInst *GEPInst =
310+
cast<GetElementPtrInst>(CE->getAsInstruction());
311+
GEPInst->insertBefore(GEPI.getIterator());
312+
SmallVector<Value *, MaxVecSize> CurrIndices(GEPInst->indices());
313+
314+
// Create a new GEP instruction
315+
Type *SourceTy = GEPInst->getSourceElementType();
316+
CurrentPtr =
317+
Builder.CreateGEP(SourceTy, CurrentPtr, CurrIndices, GEPInst->getName(),
318+
GEPInst->getNoWrapFlags());
319+
320+
// If this is the outermost GEP, update the main GEPI
321+
if (I == 0) {
322+
GEPI.setOperand(GEPI.getPointerOperandIndex(), CurrentPtr);
323+
}
324+
325+
// Clean up the temporary instruction
326+
GEPInst->eraseFromParent();
327+
}
328+
}
329+
303330
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
304331
Value *PtrOperand = GEPI.getPointerOperand();
305332
Type *OrigGEPType = GEPI.getSourceElementType();
306333
Type *NewGEPType = OrigGEPType;
307334
bool NeedsTransform = false;
335+
// Check if the pointer operand is a ConstantExpr GEP
336+
if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
337+
PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
338+
339+
// Collect all nested GEPs in the chain
340+
SmallVector<ConstantExpr *, 4> GEPChain;
341+
Value *BasePointer = PtrOpGEPCE->getOperand(0);
342+
GEPChain.push_back(PtrOpGEPCE);
343+
344+
// Walk up the chain to find all nested GEPs and the base pointer
345+
while (auto *NextGEP = dyn_cast<ConstantExpr>(BasePointer)) {
346+
if (NextGEP->getOpcode() != Instruction::GetElementPtr)
347+
break;
348+
349+
GEPChain.push_back(NextGEP);
350+
BasePointer = NextGEP->getOperand(0);
351+
}
352+
353+
// Check if the base pointer is a global that needs replacement
354+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(BasePointer)) {
355+
IRBuilder<> Builder(&GEPI);
356+
357+
// Create a new GEP for the innermost GEP (last in the chain)
358+
ConstantExpr *InnerGEPCE = GEPChain.back();
359+
GetElementPtrInst *InnerGEP =
360+
cast<GetElementPtrInst>(InnerGEPCE->getAsInstruction());
361+
InnerGEP->insertBefore(GEPI.getIterator());
362+
363+
SmallVector<Value *, MaxVecSize> Indices(InnerGEP->indices());
364+
Type *NewGEPType = NewGlobal->getValueType();
365+
Value *NewInnerGEP =
366+
Builder.CreateGEP(NewGEPType, NewGlobal, Indices, InnerGEP->getName(),
367+
InnerGEP->getNoWrapFlags());
368+
369+
// If there's only one GEP in the chain, update the main GEPI directly
370+
if (GEPChain.size() == 1)
371+
GEPI.setOperand(GEPI.getPointerOperandIndex(), NewInnerGEP);
372+
else
373+
// For multiple GEPs, we need to create a chain of GEPs
374+
buildConstExprGEPChain(GEPI, NewInnerGEP, GEPChain, Builder);
375+
376+
// Clean up the innermost GEP
377+
InnerGEP->eraseFromParent();
378+
return true;
379+
}
380+
}
308381

309382
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
310383
NewGEPType = NewGlobal->getValueType();
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.4-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
2+
; RUN: opt -S -passes='dxil-data-scalarization,function(scalarizer<load-store>),dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.4-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
3+
4+
@aTile = hidden addrspace(3) global [10 x [10 x <4 x i32>]] zeroinitializer, align 16
5+
@bTile = hidden addrspace(3) global [10 x [10 x i32]] zeroinitializer, align 16
6+
@cTile = internal global [2 x [2 x <2 x i32>]] zeroinitializer, align 16
7+
@dTile = internal global [2 x [2 x [2 x <2 x i32>]]] zeroinitializer, align 16
8+
9+
define void @CSMain() {
10+
; CHECK-LABEL: define void @CSMain() {
11+
; CHECK-NEXT: [[ENTRY:.*:]]
12+
; CHECK-NEXT: [[AFRAGPACKED_I_SCALARIZE:%.*]] = alloca [4 x i32], align 16
13+
;
14+
; SCHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [10 x <4 x i32>], ptr addrspace(3) getelementptr inbounds ([10 x [10 x [4 x i32]]], ptr addrspace(3) @aTile.scalarized, i32 0, i32 1), i32 0, i32 2
15+
; SCHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr addrspace(3) [[TMP0]], align 16
16+
; SCHECK-NEXT: store <4 x i32> [[TMP1]], ptr [[AFRAGPACKED_I_SCALARIZE]], align 16
17+
;
18+
; FCHECK-NEXT: [[AFRAGPACKED_I_SCALARIZE_I14:%.*]] = getelementptr [4 x i32], ptr [[AFRAGPACKED_I_SCALARIZE]], i32 0, i32 1
19+
; FCHECK-NEXT: [[AFRAGPACKED_I_SCALARIZE_I25:%.*]] = getelementptr [4 x i32], ptr [[AFRAGPACKED_I_SCALARIZE]], i32 0, i32 2
20+
; FCHECK-NEXT: [[AFRAGPACKED_I_SCALARIZE_I36:%.*]] = getelementptr [4 x i32], ptr [[AFRAGPACKED_I_SCALARIZE]], i32 0, i32 3
21+
; FCHECK-NEXT: [[DOTI07:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([400 x i32], ptr addrspace(3) @aTile.scalarized.1dim, i32 0, i32 48), align 16
22+
; FCHECK-NEXT: [[DOTI119:%.*]] = load i32, ptr addrspace(3) getelementptr ([400 x i32], ptr addrspace(3) @aTile.scalarized.1dim, i32 0, i32 49), align 4
23+
; FCHECK-NEXT: [[DOTI2211:%.*]] = load i32, ptr addrspace(3) getelementptr ([400 x i32], ptr addrspace(3) @aTile.scalarized.1dim, i32 0, i32 50), align 8
24+
; FCHECK-NEXT: [[DOTI3313:%.*]] = load i32, ptr addrspace(3) getelementptr ([400 x i32], ptr addrspace(3) @aTile.scalarized.1dim, i32 0, i32 51), align 4
25+
; FCHECK-NEXT: store i32 [[DOTI07]], ptr [[AFRAGPACKED_I_SCALARIZE]], align 16
26+
; FCHECK-NEXT: store i32 [[DOTI119]], ptr [[AFRAGPACKED_I_SCALARIZE_I14]], align 4
27+
; FCHECK-NEXT: store i32 [[DOTI2211]], ptr [[AFRAGPACKED_I_SCALARIZE_I25]], align 8
28+
; FCHECK-NEXT: store i32 [[DOTI3313]], ptr [[AFRAGPACKED_I_SCALARIZE_I36]], align 4
29+
;
30+
; CHECK-NEXT: ret void
31+
entry:
32+
%aFragPacked.i = alloca <4 x i32>, align 16
33+
%0 = load <4 x i32>, ptr addrspace(3) getelementptr inbounds ([10 x <4 x i32>], ptr addrspace(3) getelementptr inbounds ([10 x [10 x <4 x i32>]], ptr addrspace(3) @aTile, i32 0, i32 1), i32 0, i32 2), align 16
34+
store <4 x i32> %0, ptr %aFragPacked.i, align 16
35+
ret void
36+
}
37+
38+
define void @Main() {
39+
; CHECK-LABEL: define void @Main() {
40+
; CHECK-NEXT: [[ENTRY:.*:]]
41+
; CHECK-NEXT: [[BFRAGPACKED_I:%.*]] = alloca i32, align 16
42+
;
43+
; SCHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [10 x i32], ptr addrspace(3) getelementptr inbounds ([10 x [10 x i32]], ptr addrspace(3) @bTile, i32 0, i32 1), i32 0, i32 1
44+
; SCHECK-NEXT: [[TMP1:%.*]] = load i32, ptr addrspace(3) [[TMP0]], align 16
45+
; SCHECK-NEXT: store i32 [[TMP1]], ptr [[BFRAGPACKED_I]], align 16
46+
;
47+
; FCHECK-NEXT: [[TMP0:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([100 x i32], ptr addrspace(3) @bTile.1dim, i32 0, i32 11), align 16
48+
; FCHECK-NEXT: store i32 [[TMP0]], ptr [[BFRAGPACKED_I]], align 16
49+
;
50+
; CHECK-NEXT: ret void
51+
entry:
52+
%bFragPacked.i = alloca i32, align 16
53+
%0 = load i32, ptr addrspace(3) getelementptr inbounds ([10 x i32], ptr addrspace(3) getelementptr inbounds ([10 x [10 x i32]], ptr addrspace(3) @bTile, i32 0, i32 1), i32 0, i32 1), align 16
54+
store i32 %0, ptr %bFragPacked.i, align 16
55+
ret void
56+
}
57+
58+
define void @global_nested_geps_3d() {
59+
; CHECK-LABEL: define void @global_nested_geps_3d() {
60+
; SCHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <2 x i32>, ptr getelementptr inbounds ([2 x <2 x i32>], ptr getelementptr inbounds ([2 x [2 x [2 x i32]]], ptr @cTile.scalarized, i32 0, i32 1), i32 0, i32 1), i32 0, i32 1
61+
; SCHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
62+
;
63+
; FCHECK-NEXT: [[TMP1:%.*]] = load i32, ptr getelementptr inbounds ([8 x i32], ptr @cTile.scalarized.1dim, i32 0, i32 7), align 4
64+
;
65+
; CHECK-NEXT: ret void
66+
%1 = load i32, i32* getelementptr inbounds (<2 x i32>, <2 x i32>* getelementptr inbounds ([2 x <2 x i32>], [2 x <2 x i32>]* getelementptr inbounds ([2 x [2 x <2 x i32>]], [2 x [2 x <2 x i32>]]* @cTile, i32 0, i32 1), i32 0, i32 1), i32 0, i32 1), align 4
67+
ret void
68+
}
69+
70+
define void @global_nested_geps_4d() {
71+
; CHECK-LABEL: define void @global_nested_geps_4d() {
72+
; SCHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <2 x i32>, ptr getelementptr inbounds ([2 x <2 x i32>], ptr getelementptr inbounds ([2 x [2 x <2 x i32>]], ptr getelementptr inbounds ([2 x [2 x [2 x [2 x i32]]]], ptr @dTile.scalarized, i32 0, i32 1), i32 0, i32 1), i32 0, i32 1), i32 0, i32 1
73+
; SCHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
74+
;
75+
; FCHECK-NEXT: [[TMP1:%.*]] = load i32, ptr getelementptr inbounds ([16 x i32], ptr @dTile.scalarized.1dim, i32 0, i32 15), align 4
76+
;
77+
; CHECK-NEXT: ret void
78+
%1 = load i32, i32* getelementptr inbounds (<2 x i32>, <2 x i32>* getelementptr inbounds ([2 x <2 x i32>], [2 x <2 x i32>]* getelementptr inbounds ([2 x [2 x <2 x i32>]], [2 x [2 x <2 x i32>]]* getelementptr inbounds ([2 x [2 x [2 x <2 x i32>]]], [2 x [2 x [2 x <2 x i32>]]]* @dTile, i32 0, i32 1), i32 0, i32 1), i32 0, i32 1), i32 0, i32 1), align 4
79+
ret void
80+
}

0 commit comments

Comments
 (0)