Skip to content

Commit 25bf86f

Browse files
authored
[SPIRV] Add pass to replace gethandlefromimplicitbinding (llvm#146756)
The HLSL frontend generates call to the intrinsic @llvm.spv.resource.handlefromimplicitbinding to be able to access a resource where the set and binding were not explicitly given in the source code. Determining the correct set and binding cannot be done during Clang's codegen or earlier because in DXIL, they must first remove resource that are not accessed before assigning binding locations to the resource without an explicit binding. We will follow their lead. This is a change from DXC, where implicit binding for SPIR-V are assigned before optimizations. See llvm/wg-hlsl#309
1 parent f538f1a commit 25bf86f

File tree

5 files changed

+238
-0
lines changed

5 files changed

+238
-0
lines changed

llvm/lib/Target/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_llvm_target(SPIRVCodeGen
2626
SPIRVGlobalRegistry.cpp
2727
SPIRVInstrInfo.cpp
2828
SPIRVInstructionSelector.cpp
29+
SPIRVLegalizeImplicitBinding.cpp
2930
SPIRVStripConvergentIntrinsics.cpp
3031
SPIRVLegalizePointerCast.cpp
3132
SPIRVMergeRegionExitTargets.cpp

llvm/lib/Target/SPIRV/SPIRV.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
2323
FunctionPass *createSPIRVStructurizerPass();
2424
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
2525
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
26+
ModulePass *createSPIRVLegalizeImplicitBindingPass();
2627
FunctionPass *createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM);
2728
FunctionPass *createSPIRVRegularizerPass();
2829
FunctionPass *createSPIRVPreLegalizerCombiner();
@@ -49,6 +50,7 @@ void initializeSPIRVRegularizerPass(PassRegistry &);
4950
void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
5051
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
5152
void initializeSPIRVStripConvergentIntrinsicsPass(PassRegistry &);
53+
void initializeSPIRVLegalizeImplicitBindingPass(PassRegistry &);
5254
} // namespace llvm
5355

5456
#endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
//===- SPIRVLegalizeImplicitBinding.cpp - Legalize implicit bindings ----*- C++
2+
//-*-===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This pass legalizes the @llvm.spv.resource.handlefromimplicitbinding
11+
// intrinsic by replacing it with a call to
12+
// @llvm.spv.resource.handlefrombinding.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "SPIRV.h"
17+
#include "llvm/ADT/BitVector.h"
18+
#include "llvm/ADT/DenseMap.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
#include "llvm/IR/IRBuilder.h"
21+
#include "llvm/IR/InstVisitor.h"
22+
#include "llvm/IR/Intrinsics.h"
23+
#include "llvm/IR/IntrinsicsSPIRV.h"
24+
#include "llvm/IR/Module.h"
25+
#include "llvm/Pass.h"
26+
#include <algorithm>
27+
#include <vector>
28+
29+
using namespace llvm;
30+
31+
namespace {
32+
class SPIRVLegalizeImplicitBinding : public ModulePass {
33+
public:
34+
static char ID;
35+
SPIRVLegalizeImplicitBinding() : ModulePass(ID) {}
36+
37+
bool runOnModule(Module &M) override;
38+
39+
private:
40+
void collectBindingInfo(Module &M);
41+
uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet);
42+
void replaceImplicitBindingCalls(Module &M);
43+
44+
// A map from descriptor set to a bit vector of used binding numbers.
45+
std::vector<BitVector> UsedBindings;
46+
// A list of all implicit binding calls, to be sorted by order ID.
47+
SmallVector<CallInst *, 16> ImplicitBindingCalls;
48+
};
49+
50+
struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> {
51+
std::vector<BitVector> &UsedBindings;
52+
SmallVector<CallInst *, 16> &ImplicitBindingCalls;
53+
54+
BindingInfoCollector(std::vector<BitVector> &UsedBindings,
55+
SmallVector<CallInst *, 16> &ImplicitBindingCalls)
56+
: UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) {
57+
}
58+
59+
void visitCallInst(CallInst &CI) {
60+
if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) {
61+
const uint32_t DescSet =
62+
cast<ConstantInt>(CI.getArgOperand(0))->getZExtValue();
63+
const uint32_t Binding =
64+
cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
65+
66+
if (UsedBindings.size() <= DescSet) {
67+
UsedBindings.resize(DescSet + 1);
68+
UsedBindings[DescSet].resize(64);
69+
}
70+
if (UsedBindings[DescSet].size() <= Binding) {
71+
UsedBindings[DescSet].resize(2 * Binding + 1);
72+
}
73+
UsedBindings[DescSet].set(Binding);
74+
} else if (CI.getIntrinsicID() ==
75+
Intrinsic::spv_resource_handlefromimplicitbinding) {
76+
ImplicitBindingCalls.push_back(&CI);
77+
}
78+
}
79+
};
80+
81+
void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) {
82+
BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls);
83+
InfoCollector.visit(M);
84+
85+
// Sort the collected calls by their order ID.
86+
std::sort(
87+
ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(),
88+
[](const CallInst *A, const CallInst *B) {
89+
const uint32_t OrderIdArgIdx = 0;
90+
const uint32_t OrderA =
91+
cast<ConstantInt>(A->getArgOperand(OrderIdArgIdx))->getZExtValue();
92+
const uint32_t OrderB =
93+
cast<ConstantInt>(B->getArgOperand(OrderIdArgIdx))->getZExtValue();
94+
return OrderA < OrderB;
95+
});
96+
}
97+
98+
uint32_t SPIRVLegalizeImplicitBinding::getAndReserveFirstUnusedBinding(
99+
uint32_t DescSet) {
100+
if (UsedBindings.size() <= DescSet) {
101+
UsedBindings.resize(DescSet + 1);
102+
UsedBindings[DescSet].resize(64);
103+
}
104+
105+
int NewBinding = UsedBindings[DescSet].find_first_unset();
106+
if (NewBinding == -1) {
107+
NewBinding = UsedBindings[DescSet].size();
108+
UsedBindings[DescSet].resize(2 * NewBinding + 1);
109+
}
110+
111+
UsedBindings[DescSet].set(NewBinding);
112+
return NewBinding;
113+
}
114+
115+
void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) {
116+
for (CallInst *OldCI : ImplicitBindingCalls) {
117+
IRBuilder<> Builder(OldCI);
118+
const uint32_t DescSet =
119+
cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
120+
const uint32_t NewBinding = getAndReserveFirstUnusedBinding(DescSet);
121+
122+
SmallVector<Value *, 8> Args;
123+
Args.push_back(Builder.getInt32(DescSet));
124+
Args.push_back(Builder.getInt32(NewBinding));
125+
126+
// Copy the remaining arguments from the old call.
127+
for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
128+
Args.push_back(OldCI->getArgOperand(i));
129+
}
130+
131+
Function *NewFunc = Intrinsic::getOrInsertDeclaration(
132+
&M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
133+
CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
134+
NewCI->setCallingConv(OldCI->getCallingConv());
135+
136+
OldCI->replaceAllUsesWith(NewCI);
137+
OldCI->eraseFromParent();
138+
}
139+
}
140+
141+
bool SPIRVLegalizeImplicitBinding::runOnModule(Module &M) {
142+
collectBindingInfo(M);
143+
if (ImplicitBindingCalls.empty()) {
144+
return false;
145+
}
146+
147+
replaceImplicitBindingCalls(M);
148+
return true;
149+
}
150+
} // namespace
151+
152+
char SPIRVLegalizeImplicitBinding::ID = 0;
153+
154+
INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding",
155+
"Legalize SPIR-V implicit bindings", false, false)
156+
157+
ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() {
158+
return new SPIRVLegalizeImplicitBinding();
159+
}

llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ void SPIRVPassConfig::addIRPasses() {
226226
}
227227

228228
void SPIRVPassConfig::addISelPrepare() {
229+
addPass(createSPIRVLegalizeImplicitBindingPass());
229230
addPass(createSPIRVEmitIntrinsicsPass(&getTM<SPIRVTargetMachine>()));
230231
if (TM.getSubtargetImpl()->isLogicalSPIRV())
231232
addPass(createSPIRVLegalizePointerCastPass(&getTM<SPIRVTargetMachine>()));
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-vulkan1.3-library %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
3+
4+
@.str = private unnamed_addr constant [2 x i8] c"b\00", align 1
5+
@.str.2 = private unnamed_addr constant [2 x i8] c"c\00", align 1
6+
@.str.4 = private unnamed_addr constant [2 x i8] c"d\00", align 1
7+
@.str.6 = private unnamed_addr constant [2 x i8] c"e\00", align 1
8+
@.str.8 = private unnamed_addr constant [2 x i8] c"f\00", align 1
9+
@.str.10 = private unnamed_addr constant [2 x i8] c"g\00", align 1
10+
@.str.12 = private unnamed_addr constant [2 x i8] c"h\00", align 1
11+
@.str.14 = private unnamed_addr constant [2 x i8] c"i\00", align 1
12+
13+
; CHECK-DAG: OpName [[b:%[0-9]+]] "b"
14+
; CHECK-DAG: OpName [[c:%[0-9]+]] "c"
15+
; CHECK-DAG: OpName [[d:%[0-9]+]] "d"
16+
; CHECK-DAG: OpName [[e:%[0-9]+]] "e"
17+
; CHECK-DAG: OpName [[f:%[0-9]+]] "f"
18+
; CHECK-DAG: OpName [[g:%[0-9]+]] "g"
19+
; CHECK-DAG: OpName [[h:%[0-9]+]] "h"
20+
; CHECK-DAG: OpName [[i:%[0-9]+]] "i"
21+
; CHECK-DAG: OpDecorate [[b]] DescriptorSet 0
22+
; CHECK-DAG: OpDecorate [[b]] Binding 1
23+
; CHECK-DAG: OpDecorate [[c]] DescriptorSet 0
24+
; CHECK-DAG: OpDecorate [[c]] Binding 0
25+
; CHECK-DAG: OpDecorate [[d]] DescriptorSet 0
26+
; CHECK-DAG: OpDecorate [[d]] Binding 3
27+
; CHECK-DAG: OpDecorate [[e]] DescriptorSet 0
28+
; CHECK-DAG: OpDecorate [[e]] Binding 2
29+
; CHECK-DAG: OpDecorate [[f]] DescriptorSet 10
30+
; CHECK-DAG: OpDecorate [[f]] Binding 1
31+
; CHECK-DAG: OpDecorate [[g]] DescriptorSet 10
32+
; CHECK-DAG: OpDecorate [[g]] Binding 0
33+
; CHECK-DAG: OpDecorate [[h]] DescriptorSet 10
34+
; CHECK-DAG: OpDecorate [[h]] Binding 3
35+
; CHECK-DAG: OpDecorate [[i]] DescriptorSet 10
36+
; CHECK-DAG: OpDecorate [[i]] Binding 2
37+
38+
39+
define void @main() local_unnamed_addr #0 {
40+
entry:
41+
%0 = tail call target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) @llvm.spv.resource.handlefromimplicitbinding.tspirv.SignedImage_i32_5_2_0_0_2_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr nonnull @.str)
42+
%1 = tail call target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) @llvm.spv.resource.handlefrombinding.tspirv.SignedImage_i32_5_2_0_0_2_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr nonnull @.str.2)
43+
%2 = tail call target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) @llvm.spv.resource.handlefromimplicitbinding.tspirv.SignedImage_i32_5_2_0_0_2_0t(i32 1, i32 0, i32 1, i32 0, i1 false, ptr nonnull @.str.4)
44+
%3 = tail call target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) @llvm.spv.resource.handlefrombinding.tspirv.SignedImage_i32_5_2_0_0_2_0t(i32 0, i32 2, i32 1, i32 0, i1 false, ptr nonnull @.str.6)
45+
%4 = tail call target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) @llvm.spv.resource.handlefrombinding.tspirv.SignedImage_i32_5_2_0_0_2_0t(i32 10, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str.8)
46+
%5 = tail call target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) @llvm.spv.resource.handlefromimplicitbinding.tspirv.SignedImage_i32_5_2_0_0_2_0t(i32 2, i32 10, i32 1, i32 0, i1 false, ptr nonnull @.str.10)
47+
%6 = tail call target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) @llvm.spv.resource.handlefromimplicitbinding.tspirv.SignedImage_i32_5_2_0_0_2_0t(i32 3, i32 10, i32 1, i32 0, i1 false, ptr nonnull @.str.12)
48+
%7 = tail call target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) @llvm.spv.resource.handlefrombinding.tspirv.SignedImage_i32_5_2_0_0_2_0t(i32 10, i32 2, i32 1, i32 0, i1 false, ptr nonnull @.str.14)
49+
%8 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %1, i32 0)
50+
%9 = load i32, ptr addrspace(11) %8, align 4
51+
%10 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %2, i32 0)
52+
%11 = load i32, ptr addrspace(11) %10, align 4
53+
%add.i = add nsw i32 %11, %9
54+
%12 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %3, i32 0)
55+
%13 = load i32, ptr addrspace(11) %12, align 4
56+
%add4.i = add nsw i32 %add.i, %13
57+
%14 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %4, i32 0)
58+
%15 = load i32, ptr addrspace(11) %14, align 4
59+
%add6.i = add nsw i32 %add4.i, %15
60+
%16 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %5, i32 0)
61+
%17 = load i32, ptr addrspace(11) %16, align 4
62+
%add8.i = add nsw i32 %add6.i, %17
63+
%18 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %6, i32 0)
64+
%19 = load i32, ptr addrspace(11) %18, align 4
65+
%add10.i = add nsw i32 %add8.i, %19
66+
%20 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %7, i32 0)
67+
%21 = load i32, ptr addrspace(11) %20, align 4
68+
%add12.i = add nsw i32 %add10.i, %21
69+
%22 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.SignedImage_i32_5_2_0_0_2_0t(target("spirv.SignedImage", i32, 5, 2, 0, 0, 2, 0) %0, i32 0)
70+
store i32 %add12.i, ptr addrspace(11) %22, align 4
71+
ret void
72+
}
73+
74+
75+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

0 commit comments

Comments
 (0)