Skip to content

Commit 75c09b7

Browse files
authored
[DirectX] Let data scalarizer pass account for sub-types when updating GEP type (#166200)
This PR lets the `dxil-data-scalarization` pass account for a GEP whose source element type is a sub-type of the pointer operand type. The pass is updated so that the replacement GEP introduces leading zero indices, such that the result type remains the same (with the vector -> array transform). Please see the resolved issue for an annotated example. Resolves: llvm/llvm-project#165473
1 parent 83930be commit 75c09b7

File tree

3 files changed

+187
-16
lines changed

3 files changed

+187
-16
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -304,40 +304,76 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
304304
GEPOperator *GOp = cast<GEPOperator>(&GEPI);
305305
Value *PtrOperand = GOp->getPointerOperand();
306306
Type *NewGEPType = GOp->getSourceElementType();
307-
bool NeedsTransform = false;
308307

309308
// Unwrap GEP ConstantExprs to find the base operand and element type
310-
while (auto *CE = dyn_cast<ConstantExpr>(PtrOperand)) {
311-
if (auto *GEPCE = dyn_cast<GEPOperator>(CE)) {
312-
GOp = GEPCE;
313-
PtrOperand = GEPCE->getPointerOperand();
314-
NewGEPType = GEPCE->getSourceElementType();
315-
} else
316-
break;
309+
while (auto *GEPCE = dyn_cast_or_null<GEPOperator>(
310+
dyn_cast<ConstantExpr>(PtrOperand))) {
311+
GOp = GEPCE;
312+
PtrOperand = GEPCE->getPointerOperand();
313+
NewGEPType = GEPCE->getSourceElementType();
317314
}
318315

316+
Type *const OrigGEPType = NewGEPType;
317+
Value *const OrigOperand = PtrOperand;
318+
319319
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
320320
NewGEPType = NewGlobal->getValueType();
321321
PtrOperand = NewGlobal;
322-
NeedsTransform = true;
323322
} else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
324323
Type *AllocatedType = Alloca->getAllocatedType();
325324
if (isa<ArrayType>(AllocatedType) &&
326-
AllocatedType != GOp->getResultElementType()) {
325+
AllocatedType != GOp->getResultElementType())
327326
NewGEPType = AllocatedType;
328-
NeedsTransform = true;
327+
} else
328+
return false; // Only GEPs into an alloca or global variable are considered
329+
330+
// Defer changing i8 GEP types until dxil-flatten-arrays
331+
if (OrigGEPType->isIntegerTy(8))
332+
NewGEPType = OrigGEPType;
333+
334+
// If the original type is a "sub-type" of the new type, then ensure the gep
335+
// correctly zero-indexes the extra dimensions to keep the offset calculation
336+
// correct.
337+
// Eg:
338+
// i32, [4 x i32] and [8 x [4 x i32]] are sub-types of [8 x [4 x i32]], etc.
339+
//
340+
// So then:
341+
// gep [4 x i32] %idx
342+
// -> gep [8 x [4 x i32]], i32 0, i32 %idx
343+
// gep i32 %idx
344+
// -> gep [8 x [4 x i32]], i32 0, i32 0, i32 %idx
345+
uint32_t MissingDims = 0;
346+
Type *SubType = NewGEPType;
347+
348+
// The new type will be in its array version; so match accordingly.
349+
Type *const GEPArrType = equivalentArrayTypeFromVector(OrigGEPType);
350+
351+
while (SubType != GEPArrType) {
352+
MissingDims++;
353+
354+
ArrayType *ArrType = dyn_cast<ArrayType>(SubType);
355+
if (!ArrType) {
356+
assert(SubType == GEPArrType &&
357+
"GEP uses an DXIL invalid sub-type of alloca/global variable");
358+
break;
329359
}
360+
361+
SubType = ArrType->getElementType();
330362
}
331363

364+
bool NeedsTransform = OrigOperand != PtrOperand ||
365+
OrigGEPType != NewGEPType || MissingDims != 0;
366+
332367
if (!NeedsTransform)
333368
return false;
334369

335-
// Keep scalar GEPs scalar; dxil-flatten-arrays will do flattening later
336-
if (!isa<ArrayType>(GOp->getSourceElementType()))
337-
NewGEPType = GOp->getSourceElementType();
338-
339370
IRBuilder<> Builder(&GEPI);
340-
SmallVector<Value *, MaxVecSize> Indices(GOp->indices());
371+
SmallVector<Value *, MaxVecSize> Indices;
372+
373+
for (uint32_t I = 0; I < MissingDims; I++)
374+
Indices.push_back(Builder.getInt32(0));
375+
llvm::append_range(Indices, GOp->indices());
376+
341377
Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
342378
GOp->getName(), GOp->getNoWrapFlags());
343379

llvm/test/CodeGen/DirectX/scalarize-alloca.ll

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,68 @@ define void @alloca_2d_gep_test() {
4242
%3 = getelementptr inbounds nuw [2 x <2 x i32>], ptr %1, i32 0, i32 %2
4343
ret void
4444
}
45+
46+
; CHECK-LABEL: subtype_array_test
47+
define void @subtype_array_test() {
48+
; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
49+
; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4
50+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
51+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
52+
; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
53+
; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
54+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
55+
; CHECK: ret void
56+
%arr = alloca [8 x [4 x i32]], align 4
57+
%i = tail call i32 @llvm.dx.thread.id(i32 0)
58+
%gep = getelementptr inbounds nuw [4 x i32], ptr %arr, i32 %i
59+
ret void
60+
}
61+
62+
; CHECK-LABEL: subtype_vector_test
63+
define void @subtype_vector_test() {
64+
; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
65+
; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4
66+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
67+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]]
68+
; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
69+
; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
70+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
71+
; CHECK: ret void
72+
%arr = alloca [8 x <4 x i32>], align 4
73+
%i = tail call i32 @llvm.dx.thread.id(i32 0)
74+
%gep = getelementptr inbounds nuw <4 x i32>, ptr %arr, i32 %i
75+
ret void
76+
}
77+
78+
; CHECK-LABEL: subtype_scalar_test
79+
define void @subtype_scalar_test() {
80+
; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
81+
; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4
82+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
83+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 0, i32 [[tid]]
84+
; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
85+
; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
86+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
87+
; CHECK: ret void
88+
%arr = alloca [8 x [4 x i32]], align 4
89+
%i = tail call i32 @llvm.dx.thread.id(i32 0)
90+
%gep = getelementptr inbounds nuw i32, ptr %arr, i32 %i
91+
ret void
92+
}
93+
94+
; CHECK-LABEL: subtype_i8_test
95+
define void @subtype_i8_test() {
96+
; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4
97+
; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4
98+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
99+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw i8, ptr [[alloca_val]], i32 [[tid]]
100+
; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
101+
; FCHECK: [[flatidx_lshr:%.*]] = lshr i32 [[flatidx_mul]], 2
102+
; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_lshr]]
103+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]]
104+
; CHECK: ret void
105+
%arr = alloca [8 x [4 x i32]], align 4
106+
%i = tail call i32 @llvm.dx.thread.id(i32 0)
107+
%gep = getelementptr inbounds nuw i8, ptr %arr, i32 %i
108+
ret void
109+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
; RUN: opt -S -passes='dxil-data-scalarization' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=SCHECK,CHECK
2+
; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=FCHECK,CHECK
3+
4+
@"arrayofVecData" = local_unnamed_addr addrspace(3) global [8 x <4 x i32>] zeroinitializer, align 16
5+
@"vecData" = external addrspace(3) global <4 x i32>, align 4
6+
7+
; SCHECK: [[arrayofVecData:@arrayofVecData.*]] = local_unnamed_addr addrspace(3) global [8 x [4 x i32]] zeroinitializer, align 16
8+
; FCHECK: [[arrayofVecData:@arrayofVecData.*]] = local_unnamed_addr addrspace(3) global [32 x i32] zeroinitializer, align 16
9+
; CHECK: [[vecData:@vecData.*]] = external addrspace(3) global [4 x i32], align 4
10+
11+
; CHECK-LABEL: subtype_array_test
12+
define <4 x i32> @subtype_array_test() {
13+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
14+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[tid]]
15+
; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
16+
; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
17+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
18+
; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
19+
; CHECK: ret <4 x i32> [[x]]
20+
%i = tail call i32 @llvm.dx.thread.id(i32 0)
21+
%gep = getelementptr inbounds nuw [4 x i32], ptr addrspace(3) @"arrayofVecData", i32 %i
22+
%x = load <4 x i32>, ptr addrspace(3) %gep, align 4
23+
ret <4 x i32> %x
24+
}
25+
26+
; CHECK-LABEL: subtype_vector_test
27+
define <4 x i32> @subtype_vector_test() {
28+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
29+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[tid]]
30+
; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4
31+
; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
32+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
33+
; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
34+
; CHECK: ret <4 x i32> [[x]]
35+
%i = tail call i32 @llvm.dx.thread.id(i32 0)
36+
%gep = getelementptr inbounds nuw <4 x i32>, ptr addrspace(3) @"arrayofVecData", i32 %i
37+
%x = load <4 x i32>, ptr addrspace(3) %gep, align 4
38+
ret <4 x i32> %x
39+
}
40+
41+
; CHECK-LABEL: subtype_scalar_test
42+
define <4 x i32> @subtype_scalar_test() {
43+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
44+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 0, i32 [[tid]]
45+
; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
46+
; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]]
47+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
48+
; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
49+
; CHECK: ret <4 x i32> [[x]]
50+
%i = tail call i32 @llvm.dx.thread.id(i32 0)
51+
%gep = getelementptr inbounds nuw i32, ptr addrspace(3) @"arrayofVecData", i32 %i
52+
%x = load <4 x i32>, ptr addrspace(3) %gep, align 4
53+
ret <4 x i32> %x
54+
}
55+
56+
; CHECK-LABEL: subtype_i8_test
57+
define <4 x i32> @subtype_i8_test() {
58+
; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0)
59+
; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw i8, ptr addrspace(3) [[arrayofVecData]], i32 [[tid]]
60+
; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1
61+
; FCHECK: [[flatidx_lshr:%.*]] = lshr i32 [[flatidx_mul]], 2
62+
; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_lshr]]
63+
; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]
64+
; CHECK: [[x:%.*]] = load <4 x i32>, ptr addrspace(3) [[gep]], align 4
65+
; CHECK: ret <4 x i32> [[x]]
66+
%i = tail call i32 @llvm.dx.thread.id(i32 0)
67+
%gep = getelementptr inbounds nuw i8, ptr addrspace(3) @"arrayofVecData", i32 %i
68+
%x = load <4 x i32>, ptr addrspace(3) %gep, align 4
69+
ret <4 x i32> %x
70+
}

0 commit comments

Comments
 (0)