Skip to content

Commit c69a70b

Browse files
authored
[DirectX] NonUniformResourceIndex lowering (#159608)
Introduces `llvm.{dx|spv}.resource.nonuniformindex` intrinsic that will be used when a resource index is not guaranteed to be uniform across threads (HLSL function NonUniformResourceIndex). The DXIL lowering layer looks for this intrinsic call in the resource index calculation, makes sure it is reflected in the NonUniform flag on DXIL create handle ops (`dx.op.createHandle` and `dx.op.createHandleFromBinding`), and then removes it from the module. Closes #155701
1 parent 4f33d7b commit c69a70b

File tree

5 files changed

+212
-10
lines changed

5 files changed

+212
-10
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def int_dx_resource_handlefromimplicitbinding
3939
def int_dx_resource_getpointer
4040
: DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_any_ty, llvm_i32_ty],
4141
[IntrNoMem]>;
42+
43+
// Marks a resource index as not guaranteed to be uniform across threads
// (HLSL NonUniformResourceIndex). Takes an i32 index and yields an i32.
// The DXIL lowering looks for this call in the index computation to set the
// NonUniform flag on the create-handle ops, then removes the call.
def int_dx_resource_nonuniformindex
44+
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem]>;
45+
4246
def int_dx_resource_load_typedbuffer
4347
: DefaultAttrsIntrinsic<[llvm_any_ty, llvm_i1_ty],
4448
[llvm_any_ty, llvm_i32_ty], [IntrReadMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
161161
: DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_any_ty, llvm_i32_ty],
162162
[IntrNoMem]>;
163163

164+
// Marks a resource index as not guaranteed to be uniform across threads
// (HLSL NonUniformResourceIndex). Takes an i32 index and yields an i32.
def int_spv_resource_nonuniformindex
165+
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem]>;
166+
164167
// Read a value from the image buffer. It does not translate directly to a
165168
// single OpImageRead because the result type is not necessarily a 4 element
166169
// vector.

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
1717
#include "llvm/Analysis/DXILResource.h"
1818
#include "llvm/CodeGen/Passes.h"
19+
#include "llvm/IR/Constant.h"
1920
#include "llvm/IR/DiagnosticInfo.h"
2021
#include "llvm/IR/IRBuilder.h"
2122
#include "llvm/IR/Instruction.h"
@@ -24,6 +25,7 @@
2425
#include "llvm/IR/IntrinsicsDirectX.h"
2526
#include "llvm/IR/Module.h"
2627
#include "llvm/IR/PassManager.h"
28+
#include "llvm/IR/Use.h"
2729
#include "llvm/InitializePasses.h"
2830
#include "llvm/Pass.h"
2931
#include "llvm/Support/ErrorHandling.h"
@@ -42,6 +44,7 @@ class OpLowerer {
4244
DXILResourceTypeMap &DRTM;
4345
const ModuleMetadataInfo &MMDI;
4446
SmallVector<CallInst *> CleanupCasts;
47+
Function *CleanupNURI = nullptr;
4548

4649
public:
4750
OpLowerer(Module &M, DXILResourceMap &DRM, DXILResourceTypeMap &DRTM,
@@ -195,6 +198,21 @@ class OpLowerer {
195198
CleanupCasts.clear();
196199
}
197200

201+
void cleanupNonUniformResourceIndexCalls() {
202+
// Replace all NonUniformResourceIndex calls with their argument.
203+
if (!CleanupNURI)
204+
return;
205+
for (User *U : make_early_inc_range(CleanupNURI->users())) {
206+
CallInst *CI = dyn_cast<CallInst>(U);
207+
if (!CI)
208+
continue;
209+
CI->replaceAllUsesWith(CI->getArgOperand(0));
210+
CI->eraseFromParent();
211+
}
212+
CleanupNURI->eraseFromParent();
213+
CleanupNURI = nullptr;
214+
}
215+
198216
// Remove the resource global associated with the handleFromBinding call
199217
// instruction and their uses as they aren't needed anymore.
200218
// TODO: We should verify that all the globals get removed.
@@ -229,6 +247,31 @@ class OpLowerer {
229247
NameGlobal->removeFromParent();
230248
}
231249

250+
/// Returns true if the computation of \p IndexOp involves a call to the
/// llvm.dx.resource.nonuniformindex intrinsic.
///
/// Starting at \p IndexOp, the operands of every reached User are followed
/// transitively; Constant operands are never followed. Only operand edges are
/// walked, so a marked value that is stored to memory and loaded back is not
/// detected.
bool hasNonUniformIndex(Value *IndexOp) {
  if (isa<llvm::Constant>(IndexOp))
    return false;

  SmallVector<Value *> WorkList;
  // An index computed in a loop can reach itself through a PHI node, and
  // subexpressions may be shared. Tracking visited values guarantees
  // termination and keeps the walk linear in the number of reachable values.
  SmallPtrSet<Value *, 16> Visited;
  WorkList.push_back(IndexOp);
  Visited.insert(IndexOp);

  while (!WorkList.empty()) {
    Value *V = WorkList.pop_back_val();
    // CallBase::getIntrinsicID() reports not_intrinsic for indirect calls;
    // going through getCalledFunction() would dereference null there.
    if (auto *CI = dyn_cast<CallInst>(V))
      if (CI->getIntrinsicID() == Intrinsic::dx_resource_nonuniformindex)
        return true;

    if (auto *U = llvm::dyn_cast<llvm::User>(V))
      for (llvm::Value *Op : U->operands())
        if (!isa<llvm::Constant>(Op) && Visited.insert(Op).second)
          WorkList.push_back(Op);
  }
  return false;
}
274+
232275
[[nodiscard]] bool lowerToCreateHandle(Function &F) {
233276
IRBuilder<> &IRB = OpBuilder.getIRB();
234277
Type *Int8Ty = IRB.getInt8Ty();
@@ -250,13 +293,12 @@ class OpLowerer {
250293
IndexOp = IRB.CreateAdd(IndexOp,
251294
ConstantInt::get(Int32Ty, Binding.LowerBound));
252295

253-
// FIXME: The last argument is a NonUniform flag which needs to be set
254-
// based on resource analysis.
255-
// https://github.com/llvm/llvm-project/issues/155701
296+
bool HasNonUniformIndex =
297+
(Binding.Size == 1) ? false : hasNonUniformIndex(IndexOp);
256298
std::array<Value *, 4> Args{
257299
ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
258300
ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
259-
ConstantInt::get(Int1Ty, false)};
301+
ConstantInt::get(Int1Ty, HasNonUniformIndex)};
260302
Expected<CallInst *> OpCall =
261303
OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
262304
if (Error E = OpCall.takeError())
@@ -300,11 +342,10 @@ class OpLowerer {
300342
: Binding.LowerBound + Binding.Size - 1;
301343
Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
302344
Binding.Space, RC);
303-
// FIXME: The last argument is a NonUniform flag which needs to be set
304-
// based on resource analysis.
305-
// https://github.com/llvm/llvm-project/issues/155701
306-
Constant *NonUniform = ConstantInt::get(Int1Ty, false);
307-
std::array<Value *, 3> BindArgs{ResBind, IndexOp, NonUniform};
345+
bool NonUniformIndex =
346+
(Binding.Size == 1) ? false : hasNonUniformIndex(IndexOp);
347+
Constant *NonUniformOp = ConstantInt::get(Int1Ty, NonUniformIndex);
348+
std::array<Value *, 3> BindArgs{ResBind, IndexOp, NonUniformOp};
308349
Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
309350
OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
310351
if (Error E = OpBind.takeError())
@@ -868,6 +909,11 @@ class OpLowerer {
868909
case Intrinsic::dx_resource_getpointer:
869910
HasErrors |= lowerGetPointer(F);
870911
break;
912+
case Intrinsic::dx_resource_nonuniformindex:
913+
assert(!CleanupNURI &&
914+
"overloaded llvm.dx.resource.nonuniformindex intrinsics?");
915+
CleanupNURI = &F;
916+
break;
871917
case Intrinsic::dx_resource_load_typedbuffer:
872918
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
873919
break;
@@ -908,8 +954,10 @@ class OpLowerer {
908954
}
909955
Updated = true;
910956
}
911-
if (Updated && !HasErrors)
957+
if (Updated && !HasErrors) {
912958
cleanupHandleCasts();
959+
cleanupNonUniformResourceIndexCalls();
960+
}
913961

914962
return Updated;
915963
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
; RUN: opt -S -passes=dxil-op-lower %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.0-compute"
4+
5+
; Resource name strings passed as the last argument of handlefrombinding.
@A.str = internal unnamed_addr constant [2 x i8] c"A\00", align 1
@B.str = internal unnamed_addr constant [2 x i8] c"B\00", align 1
7+
8+
declare i32 @some_val();
9+
10+
; Checks that dxil-op-lower sets the NonUniform flag of dx.op.createHandle when
; the resource index is derived from llvm.dx.resource.nonuniformindex, and that
; every call to the intrinsic is removed afterwards.
define void @test_buffers_with_nuri() {

  %val = call i32 @some_val()
  %foo = alloca i32, align 4

  ; RWBuffer<float> A[10];
  ;
  ; A[NonUniformResourceIndex(val)];

  %nuri1 = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %val)
  %res1 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 %nuri1, ptr @A.str)
  ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 %val, i1 true) #[[ATTR:.*]]
  ; CHECK-NOT: @llvm.dx.cast.handle
  ; CHECK-NOT: @llvm.dx.resource.nonuniformindex

  ; A[NonUniformResourceIndex(val + 1) % 10];
  %add1 = add i32 %val, 1
  %nuri2 = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %add1)
  %rem1 = urem i32 %nuri2, 10
  %res2 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 %rem1, ptr @A.str)
  ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 %rem1, i1 true) #[[ATTR]]

  ; A[10 + 3 * NonUniformResourceIndex(val)];
  %mul1 = mul i32 %nuri1, 3
  %add2 = add i32 %mul1, 10
  %res3 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 %add2, ptr @A.str)
  ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 %add2, i1 true) #[[ATTR]]

  ; NonUniformResourceIndex value going through store & load - the flag is not going to get picked up:
  %a = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %val)
  store i32 %a, ptr %foo
  %b = load i32, ptr %foo
  %res4 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 %b, ptr @A.str)
  ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 %b, i1 false) #[[ATTR]]

  ; NonUniformResourceIndex on a single resource (not an array) - the flag is not going to get picked up:
  ;
  ; RWBuffer<float> B : register(u0, space20);
  ; B[NonUniformResourceIndex(val)];
  %nuri3 = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %val)
  %res5 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 20, i32 0, i32 1, i32 %nuri3, ptr @B.str)
  ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 1, i32 %val, i1 false) #[[ATTR]]

  ; NonUniformResourceIndex on unrelated value - the call is removed:
  ; foo = NonUniformResourceIndex(val);
  %nuri4 = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %val)
  store i32 %nuri4, ptr %foo
  ; CHECK: store i32 %val, ptr %foo
  ; CHECK-NOT: @llvm.dx.resource.nonuniformindex

  ret void
}
67+
68+
; CHECK: attributes #[[ATTR]] = {{{.*}} memory(read) {{.*}}}
69+
70+
attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) }
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
; RUN: opt -S -passes=dxil-op-lower %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.6-compute"
4+
5+
; Resource name strings passed as the last argument of handlefrombinding.
@A.str = internal unnamed_addr constant [2 x i8] c"A\00", align 1
@B.str = internal unnamed_addr constant [2 x i8] c"B\00", align 1
7+
8+
declare i32 @some_val();
9+
10+
; Checks that dxil-op-lower sets the NonUniform flag of
; dx.op.createHandleFromBinding when the resource index is derived from
; llvm.dx.resource.nonuniformindex, and that every call to the intrinsic is
; removed afterwards.
define void @test_buffers_with_nuri() {

  %val = call i32 @some_val()
  %foo = alloca i32, align 4

  ; RWBuffer<float> A[10];
  ;
  ; A[NonUniformResourceIndex(val)];

  %nuri1 = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %val)
  %res1 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 %nuri1, ptr @A.str)
  ; CHECK: %[[RES1:.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 9, i32 0, i8 1 }, i32 %val, i1 true) #[[ATTR:.*]]
  ; CHECK: call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[RES1]], %dx.types.ResourceProperties { i32 4106, i32 265 }) #[[ATTR]]
  ; CHECK-NOT: @llvm.dx.cast.handle
  ; CHECK-NOT: @llvm.dx.resource.nonuniformindex

  ; A[NonUniformResourceIndex(val + 1) % 10];
  %add1 = add i32 %val, 1
  %nuri2 = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %add1)
  %rem1 = urem i32 %nuri2, 10
  %res2 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 %rem1, ptr @A.str)
  ; CHECK: %[[RES2:.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 9, i32 0, i8 1 }, i32 %rem1, i1 true) #[[ATTR]]
  ; CHECK: call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[RES2]], %dx.types.ResourceProperties { i32 4106, i32 265 }) #[[ATTR]]

  ; A[10 + 3 * NonUniformResourceIndex(val)];
  %mul1 = mul i32 %nuri1, 3
  %add2 = add i32 %mul1, 10
  %res3 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 %add2, ptr @A.str)
  ; CHECK: %[[RES3:.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 9, i32 0, i8 1 }, i32 %add2, i1 true) #[[ATTR]]
  ; CHECK: %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[RES3]], %dx.types.ResourceProperties { i32 4106, i32 265 }) #[[ATTR]]

  ; NonUniformResourceIndex value going through store & load: the flag is not going to get picked up
  %a = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %val)
  store i32 %a, ptr %foo
  %b = load i32, ptr %foo
  %res4 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 10, i32 %b, ptr @A.str)
  ; CHECK: %[[RES4:.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 9, i32 0, i8 1 }, i32 %b, i1 false) #[[ATTR]]
  ; CHECK: %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[RES4]], %dx.types.ResourceProperties { i32 4106, i32 265 }) #[[ATTR]]

  ; NonUniformResourceIndex on a single resource (not an array): the flag is not going to get picked up
  ; RWBuffer<float> B : register(u0, space20);
  ;
  ; B[NonUniformResourceIndex(val)];

  %nuri3 = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %val)
  %res5 = call target("dx.TypedBuffer", float, 1, 0, 0)
            @llvm.dx.resource.handlefrombinding(i32 20, i32 0, i32 1, i32 %nuri3, ptr @B.str)
  ; CHECK: %[[RES5:.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 0, i32 0, i32 20, i8 1 }, i32 %val, i1 false) #[[ATTR]]
  ; CHECK: %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle %[[RES5]], %dx.types.ResourceProperties { i32 4106, i32 265 }) #[[ATTR]]

  ; NonUniformResourceIndex on unrelated value - the call is removed:
  ; foo = NonUniformResourceIndex(val);
  %nuri4 = tail call noundef i32 @llvm.dx.resource.nonuniformindex(i32 %val)
  store i32 %nuri4, ptr %foo
  ; CHECK: store i32 %val, ptr %foo
  ; CHECK-NOT: @llvm.dx.resource.nonuniformindex

  ret void
}
74+
75+
; CHECK: attributes #[[ATTR]] = {{{.*}} memory(none) {{.*}}}
76+
77+
attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) }

0 commit comments

Comments
 (0)