Skip to content

Commit 1ef79ad

Browse files
committed
refactor hlsl binding to support non free bindings
1 parent 4d0fbbe commit 1ef79ad

File tree

4 files changed

+192
-96
lines changed

4 files changed

+192
-96
lines changed

llvm/include/llvm/Frontend/HLSL/HLSLBinding.h

Lines changed: 92 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,59 @@
1818
#include "llvm/Support/Compiler.h"
1919
#include "llvm/Support/DXILABI.h"
2020
#include "llvm/Support/ErrorHandling.h"
21+
#include <cstdint>
2122

2223
namespace llvm {
2324
namespace hlsl {
2425

26+
struct BindingRange {
27+
uint32_t LowerBound;
28+
uint32_t UpperBound;
29+
BindingRange(uint32_t LB, uint32_t UB) : LowerBound(LB), UpperBound(UB) {}
30+
31+
bool overlapsWith(const BindingRange &Other) const {
32+
return !(
33+
(Other.LowerBound < LowerBound && Other.UpperBound <= LowerBound) ||
34+
(Other.LowerBound >= UpperBound && Other.UpperBound > UpperBound));
35+
}
36+
};
37+
38+
struct BaseRegisterSpace {
39+
uint32_t Space;
40+
SmallVector<BindingRange> Ranges;
41+
BaseRegisterSpace(uint32_t Space) : Space(Space) {}
42+
43+
bool operator==(const BaseRegisterSpace &Other) const {
44+
return Space == Other.Space;
45+
}
46+
};
47+
48+
struct FreeRegisterSpace : public BaseRegisterSpace {
49+
using BaseRegisterSpace::BaseRegisterSpace;
50+
51+
FreeRegisterSpace(uint32_t Space) : BaseRegisterSpace(Space) {
52+
Ranges.emplace_back(0, ~0u);
53+
}
54+
// Size == -1 means unbounded array
55+
LLVM_ABI std::optional<uint32_t> findAvailableBinding(int32_t Size);
56+
};
57+
58+
struct BusyRegisterSpace : public BaseRegisterSpace {
59+
using BaseRegisterSpace::BaseRegisterSpace;
60+
61+
BusyRegisterSpace(uint32_t Space) : BaseRegisterSpace(Space) {}
62+
63+
LLVM_ABI bool isBound(const BindingRange &Range) const;
64+
};
65+
66+
template <typename T> struct BindingSpaces {
67+
dxil::ResourceClass RC;
68+
llvm::SmallVector<T> Spaces;
69+
BindingSpaces(dxil::ResourceClass RC) : RC(RC) {}
70+
LLVM_ABI T &getOrInsertSpace(uint32_t Space);
71+
LLVM_ABI std::optional<const T *> contains(uint32_t Space) const;
72+
};
73+
2574
/// BindingInfo represents the ranges of bindings and free space for each
2675
/// `dxil::ResourceClass`. This can represent HLSL-level bindings as well as
2776
/// bindings described in root signatures, and can be used for analysis of
@@ -44,44 +93,14 @@ namespace hlsl {
4493
/// }
4594
class BindingInfo {
4695
public:
47-
struct BindingRange {
48-
uint32_t LowerBound;
49-
uint32_t UpperBound;
50-
BindingRange(uint32_t LB, uint32_t UB) : LowerBound(LB), UpperBound(UB) {}
51-
};
52-
53-
struct RegisterSpace {
54-
uint32_t Space;
55-
SmallVector<BindingRange> FreeRanges;
56-
RegisterSpace(uint32_t Space) : Space(Space) {
57-
FreeRanges.emplace_back(0, ~0u);
58-
}
59-
// Size == -1 means unbounded array
60-
LLVM_ABI std::optional<uint32_t> findAvailableBinding(int32_t Size);
61-
LLVM_ABI bool isBound(const BindingRange &Range) const;
62-
63-
bool operator==(const RegisterSpace &Other) const {
64-
return Space == Other.Space;
65-
}
66-
};
67-
68-
struct BindingSpaces {
69-
dxil::ResourceClass RC;
70-
llvm::SmallVector<RegisterSpace> Spaces;
71-
BindingSpaces(dxil::ResourceClass RC) : RC(RC) {}
72-
LLVM_ABI RegisterSpace &getOrInsertSpace(uint32_t Space);
73-
LLVM_ABI std::optional<const BindingInfo::RegisterSpace *>
74-
contains(uint32_t Space) const;
75-
};
76-
7796
private:
78-
BindingSpaces SRVSpaces{dxil::ResourceClass::SRV};
79-
BindingSpaces UAVSpaces{dxil::ResourceClass::UAV};
80-
BindingSpaces CBufferSpaces{dxil::ResourceClass::CBuffer};
81-
BindingSpaces SamplerSpaces{dxil::ResourceClass::Sampler};
97+
BindingSpaces<FreeRegisterSpace> SRVSpaces{dxil::ResourceClass::SRV};
98+
BindingSpaces<FreeRegisterSpace> UAVSpaces{dxil::ResourceClass::UAV};
99+
BindingSpaces<FreeRegisterSpace> CBufferSpaces{dxil::ResourceClass::CBuffer};
100+
BindingSpaces<FreeRegisterSpace> SamplerSpaces{dxil::ResourceClass::Sampler};
82101

83102
public:
84-
BindingSpaces &getBindingSpaces(dxil::ResourceClass RC) {
103+
BindingSpaces<FreeRegisterSpace> &getBindingSpaces(dxil::ResourceClass RC) {
85104
switch (RC) {
86105
case dxil::ResourceClass::SRV:
87106
return SRVSpaces;
@@ -95,14 +114,46 @@ class BindingInfo {
95114

96115
llvm_unreachable("Invalid resource class");
97116
}
98-
const BindingSpaces &getBindingSpaces(dxil::ResourceClass RC) const {
117+
const BindingSpaces<FreeRegisterSpace> &
118+
getBindingSpaces(dxil::ResourceClass RC) const {
99119
return const_cast<BindingInfo *>(this)->getBindingSpaces(RC);
100120
}
101121

102122
// Size == -1 means unbounded array
103123
LLVM_ABI std::optional<uint32_t>
104124
findAvailableBinding(dxil::ResourceClass RC, uint32_t Space, int32_t Size);
105125

126+
friend class BindingInfoBuilder;
127+
};
128+
129+
class BusyBindingInfo {
130+
private:
131+
BindingSpaces<BusyRegisterSpace> SRVSpaces{dxil::ResourceClass::SRV};
132+
BindingSpaces<BusyRegisterSpace> UAVSpaces{dxil::ResourceClass::UAV};
133+
BindingSpaces<BusyRegisterSpace> CBufferSpaces{dxil::ResourceClass::CBuffer};
134+
BindingSpaces<BusyRegisterSpace> SamplerSpaces{dxil::ResourceClass::Sampler};
135+
136+
public:
137+
public:
138+
BindingSpaces<BusyRegisterSpace> &getBindingSpaces(dxil::ResourceClass RC) {
139+
switch (RC) {
140+
case dxil::ResourceClass::SRV:
141+
return SRVSpaces;
142+
case dxil::ResourceClass::UAV:
143+
return UAVSpaces;
144+
case dxil::ResourceClass::CBuffer:
145+
return CBufferSpaces;
146+
case dxil::ResourceClass::Sampler:
147+
return SamplerSpaces;
148+
}
149+
150+
llvm_unreachable("Invalid resource class");
151+
}
152+
const BindingSpaces<BusyRegisterSpace> &
153+
getBindingSpaces(dxil::ResourceClass RC) const {
154+
return const_cast<BusyBindingInfo *>(this)->getBindingSpaces(RC);
155+
}
156+
106157
LLVM_ABI bool isBound(dxil::ResourceClass RC, uint32_t Space,
107158
const BindingRange &Range) const;
108159

@@ -162,6 +213,11 @@ class BindingInfoBuilder {
162213
[&HasOverlap](auto, auto) { HasOverlap = true; });
163214
}
164215

216+
LLVM_ABI BusyBindingInfo calculateBusyBindingInfo(
217+
llvm::function_ref<void(const BindingInfoBuilder &Builder,
218+
const Binding &Overlapping)>
219+
ReportOverlap);
220+
165221
/// For use in the \c ReportOverlap callback of \c calculateBindingInfo -
166222
/// finds a binding that the \c ReportedBinding overlaps with.
167223
LLVM_ABI const Binding &findOverlapping(const Binding &ReportedBinding) const;

llvm/lib/Frontend/HLSL/HLSLBinding.cpp

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@
1010
#include "llvm/ADT/STLExtras.h"
1111
#include "llvm/Support/Error.h"
1212
#include <optional>
13+
#include <utility>
1314

1415
using namespace llvm;
1516
using namespace hlsl;
1617

1718
std::optional<uint32_t>
1819
BindingInfo::findAvailableBinding(dxil::ResourceClass RC, uint32_t Space,
1920
int32_t Size) {
20-
BindingSpaces &BS = getBindingSpaces(RC);
21-
RegisterSpace &RS = BS.getOrInsertSpace(Space);
21+
BindingSpaces<FreeRegisterSpace> &BS = getBindingSpaces(RC);
22+
FreeRegisterSpace &RS = BS.getOrInsertSpace(Space);
2223
return RS.findAvailableBinding(Size);
2324
}
2425

25-
BindingInfo::RegisterSpace &
26-
BindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) {
26+
template <typename T> T &BindingSpaces<T>::getOrInsertSpace(uint32_t Space) {
2727
for (auto It = Spaces.begin(), End = Spaces.end(); It != End; ++It) {
2828
if (It->Space == Space)
2929
return *It;
@@ -34,34 +34,33 @@ BindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) {
3434
return Spaces.emplace_back(Space);
3535
}
3636

37-
std::optional<const BindingInfo::RegisterSpace *>
38-
BindingInfo::BindingSpaces::contains(uint32_t Space) const {
39-
const BindingInfo::RegisterSpace *It = llvm::find(Spaces, Space);
37+
template <typename T>
38+
std::optional<const T *> BindingSpaces<T>::contains(uint32_t Space) const {
39+
const T *It = llvm::find(Spaces, Space);
4040
if (It == Spaces.end())
4141
return std::nullopt;
4242
return It;
4343
}
4444

45-
std::optional<uint32_t>
46-
BindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) {
45+
std::optional<uint32_t> FreeRegisterSpace::findAvailableBinding(int32_t Size) {
4746
assert((Size == -1 || Size > 0) && "invalid size");
4847

49-
if (FreeRanges.empty())
48+
if (Ranges.empty())
5049
return std::nullopt;
5150

5251
// unbounded array
5352
if (Size == -1) {
54-
BindingRange &Last = FreeRanges.back();
53+
BindingRange &Last = Ranges.back();
5554
if (Last.UpperBound != ~0u)
5655
// this space is already occupied by an unbounded array
5756
return std::nullopt;
5857
uint32_t RegSlot = Last.LowerBound;
59-
FreeRanges.pop_back();
58+
Ranges.pop_back();
6059
return RegSlot;
6160
}
6261

6362
// single resource or fixed-size array
64-
for (BindingRange &R : FreeRanges) {
63+
for (BindingRange &R : Ranges) {
6564
// compare the size as uint64_t to prevent overflow for range (0, ~0u)
6665
if ((uint64_t)R.UpperBound - R.LowerBound + 1 < (uint64_t)Size)
6766
continue;
@@ -76,28 +75,6 @@ BindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) {
7675
return std::nullopt;
7776
}
7877

79-
bool BindingInfo::RegisterSpace::isBound(const BindingRange &Range) const {
80-
const BindingRange *It = llvm::lower_bound(
81-
FreeRanges, Range.LowerBound,
82-
[](const BindingRange &R, uint32_t Val) { return R.UpperBound <= Val; });
83-
84-
if (It == FreeRanges.end())
85-
return true;
86-
return ((Range.LowerBound < It->LowerBound) &&
87-
(Range.UpperBound < It->LowerBound)) ||
88-
((Range.LowerBound > It->UpperBound) &&
89-
(Range.UpperBound > It->UpperBound));
90-
}
91-
92-
bool BindingInfo::isBound(dxil::ResourceClass RC, uint32_t Space,
93-
const BindingRange &Range) const {
94-
const BindingSpaces &BS = getBindingSpaces(RC);
95-
std::optional<const BindingInfo::RegisterSpace *> RS = BS.contains(Space);
96-
if (!RS)
97-
return false;
98-
return RS.value()->isBound(Range);
99-
}
100-
10178
BindingInfo BindingInfoBuilder::calculateBindingInfo(
10279
llvm::function_ref<void(const BindingInfoBuilder &Builder,
10380
const Binding &Overlapping)>
@@ -115,16 +92,16 @@ BindingInfo BindingInfoBuilder::calculateBindingInfo(
11592
// Go over the sorted bindings and build up lists of free register ranges
11693
// for each binding type and used spaces. Bindings are sorted by resource
11794
// class, space, and lower bound register slot.
118-
BindingInfo::BindingSpaces *BS =
95+
BindingSpaces<FreeRegisterSpace> *BS =
11996
&Info.getBindingSpaces(dxil::ResourceClass::SRV);
12097
for (const Binding &B : Bindings) {
12198
if (BS->RC != B.RC)
12299
// move to the next resource class spaces
123100
BS = &Info.getBindingSpaces(B.RC);
124101

125-
BindingInfo::RegisterSpace *S = BS->Spaces.empty()
126-
? &BS->Spaces.emplace_back(B.Space)
127-
: &BS->Spaces.back();
102+
FreeRegisterSpace *S = BS->Spaces.empty()
103+
? &BS->Spaces.emplace_back(B.Space)
104+
: &BS->Spaces.back();
128105
assert(S->Space <= B.Space && "bindings not sorted correctly?");
129106
if (B.Space != S->Space)
130107
// add new space
@@ -133,21 +110,21 @@ BindingInfo BindingInfoBuilder::calculateBindingInfo(
133110
// The space is full - there are no free slots left, or the rest of the
134111
// slots are taken by an unbounded array. Report the overlapping to the
135112
// caller.
136-
if (S->FreeRanges.empty() || S->FreeRanges.back().UpperBound < ~0u) {
113+
if (S->Ranges.empty() || S->Ranges.back().UpperBound < ~0u) {
137114
ReportOverlap(*this, B);
138115
continue;
139116
}
140117
// adjust the last free range lower bound, split it in two, or remove it
141-
BindingInfo::BindingRange &LastFreeRange = S->FreeRanges.back();
118+
BindingRange &LastFreeRange = S->Ranges.back();
142119
if (LastFreeRange.LowerBound == B.LowerBound) {
143120
if (B.UpperBound < ~0u)
144121
LastFreeRange.LowerBound = B.UpperBound + 1;
145122
else
146-
S->FreeRanges.pop_back();
123+
S->Ranges.pop_back();
147124
} else if (LastFreeRange.LowerBound < B.LowerBound) {
148125
LastFreeRange.UpperBound = B.LowerBound - 1;
149126
if (B.UpperBound < ~0u)
150-
S->FreeRanges.emplace_back(B.UpperBound + 1, ~0u);
127+
S->Ranges.emplace_back(B.UpperBound + 1, ~0u);
151128
} else {
152129
// We don't have room here. Report the overlapping binding to the caller
153130
// and mark any extra space this binding would use as unavailable.
@@ -156,7 +133,7 @@ BindingInfo BindingInfoBuilder::calculateBindingInfo(
156133
LastFreeRange.LowerBound =
157134
std::max(LastFreeRange.LowerBound, B.UpperBound + 1);
158135
else
159-
S->FreeRanges.pop_back();
136+
S->Ranges.pop_back();
160137
}
161138
}
162139

@@ -172,3 +149,66 @@ const BindingInfoBuilder::Binding &BindingInfoBuilder::findOverlapping(
172149

173150
llvm_unreachable("Searching for overlap for binding that does not overlap");
174151
}
152+
153+
bool BusyRegisterSpace::isBound(const BindingRange &Range) const {
154+
const BindingRange *It = llvm::lower_bound(
155+
Ranges, Range.LowerBound,
156+
[](const BindingRange &R, uint32_t Val) { return R.UpperBound < Val; });
157+
158+
if (It == Ranges.end())
159+
return false;
160+
return ((Range.LowerBound >= It->LowerBound) &&
161+
(Range.UpperBound <= It->UpperBound));
162+
}
163+
164+
bool BusyBindingInfo::isBound(dxil::ResourceClass RC, uint32_t Space,
165+
const BindingRange &Range) const {
166+
const BindingSpaces<BusyRegisterSpace> &BS = getBindingSpaces(RC);
167+
std::optional<const BusyRegisterSpace *> RS = BS.contains(Space);
168+
if (!RS)
169+
return false;
170+
return RS.value()->isBound(Range);
171+
}
172+
173+
BusyBindingInfo BindingInfoBuilder::calculateBusyBindingInfo(
174+
llvm::function_ref<void(const BindingInfoBuilder &Builder,
175+
const Binding &Overlapping)>
176+
ReportOverlap) {
177+
// sort all the collected bindings
178+
llvm::stable_sort(Bindings);
179+
180+
// remove duplicates
181+
Binding *NewEnd = llvm::unique(Bindings);
182+
if (NewEnd != Bindings.end())
183+
Bindings.erase(NewEnd, Bindings.end());
184+
185+
BusyBindingInfo Info;
186+
187+
BindingSpaces<BusyRegisterSpace> *BS =
188+
&Info.getBindingSpaces(dxil::ResourceClass::SRV);
189+
for (const Binding &B : Bindings) {
190+
if (BS->RC != B.RC)
191+
// move to the next resource class spaces
192+
BS = &Info.getBindingSpaces(B.RC);
193+
194+
BusyRegisterSpace *S = BS->Spaces.empty()
195+
? &BS->Spaces.emplace_back(B.Space)
196+
: &BS->Spaces.back();
197+
assert(S->Space <= B.Space && "bindings not sorted correctly?");
198+
199+
if (B.Space != S->Space)
200+
S = &BS->Spaces.emplace_back(B.Space);
201+
202+
if (!S->Ranges.empty()) {
203+
// check for overlap with the last range only, since the bindings are
204+
// sorted and there cannot be any overlap with earlier ranges.
205+
const BindingRange Back = S->Ranges.back();
206+
if (Back.overlapsWith({B.LowerBound, B.UpperBound})) {
207+
ReportOverlap(*this, B);
208+
continue;
209+
}
210+
}
211+
S->Ranges.emplace_back(B.LowerBound, B.UpperBound);
212+
}
213+
return Info;
214+
}

llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ static void validateRootSignature(Module &M,
227227
Builder.trackBinding(dxil::ResourceClass::Sampler, S.RegisterSpace,
228228
S.ShaderRegister, S.ShaderRegister, &S);
229229

230-
hlsl::BindingInfo Info = Builder.calculateBindingInfo(
230+
hlsl::BusyBindingInfo Info = Builder.calculateBusyBindingInfo(
231231
[&M](const llvm::hlsl::BindingInfoBuilder &Builder,
232232
const llvm::hlsl::BindingInfoBuilder::Binding &ReportedBinding) {
233233
const llvm::hlsl::BindingInfoBuilder::Binding &Overlaping =

0 commit comments

Comments
 (0)