Skip to content

Commit a222856

Browse files
committed
[HLSL][RootSignature] Implement resource register validation
1 parent a926c61 commit a222856

File tree

4 files changed

+295
-0
lines changed

4 files changed

+295
-0
lines changed

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
1616

1717
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/IntervalMap.h"
1819
#include "llvm/Support/DXILABI.h"
1920
#include "llvm/Support/raw_ostream.h"
2021
#include <variant>
@@ -203,6 +204,61 @@ class MetadataBuilder {
203204
SmallVector<Metadata *> GeneratedMetadata;
204205
};
205206

207+
// RangeInfo holds the information to correctly construct a ResourceRange
208+
// and retains this information to be used for displaying a better diagnostic
209+
struct RangeInfo {
210+
const static uint32_t Unbounded = static_cast<uint32_t>(-1);
211+
212+
uint32_t LowerBound;
213+
uint32_t UpperBound;
214+
};
215+
216+
class ResourceRange {
217+
public:
218+
using IMap = llvm::IntervalMap<uint32_t, const RangeInfo *, 16,
219+
llvm::IntervalMapInfo<uint32_t>>;
220+
221+
private:
222+
IMap Intervals;
223+
224+
public:
225+
ResourceRange(IMap::Allocator &Allocator) : Intervals(IMap(Allocator)) {}
226+
227+
// Returns a reference to the first RangeInfo that overlaps with
228+
// [Info.LowerBound;Info.UpperBound], or, std::nullopt if there is no overlap
229+
std::optional<const RangeInfo *> getOverlapping(const RangeInfo &Info) const;
230+
231+
// Return the mapped RangeInfo at X or nullptr if no mapping exists
232+
const RangeInfo *lookup(uint32_t X) const;
233+
234+
// Insert the required (sub-)intervals such that the interval of [a;b] =
235+
// [Info.LowerBound, Info.UpperBound] is covered and points to a valid
236+
// RangeInfo &.
237+
//
238+
// For instance consider the following chain of inserting RangeInfos with the
239+
// intervals denoting the Lower/Upper-bounds:
240+
//
241+
// A = [0;2]
242+
// insert(A) -> false
243+
// intervals: [0;2] -> &A
244+
// B = [5;7]
245+
// insert(B) -> false
246+
// intervals: [0;2] -> &A, [5;7] -> &B
247+
// C = [4;7]
248+
// insert(C) -> true
249+
// intervals: [0;2] -> &A, [4;7] -> &C
250+
// D = [1;5]
251+
// insert(D) -> true
252+
// intervals: [0;2] -> &A, [3;3] -> &D, [4;7] -> &C
253+
// E = [0;unbounded]
254+
// insert(E) -> true
255+
// intervals: [0;unbounded] -> E
256+
//
257+
// Returns if the first overlapping range when inserting
258+
// (same return as getOverlapping)
259+
std::optional<const RangeInfo *> insert(const RangeInfo &Info);
260+
};
261+
206262
} // namespace rootsig
207263
} // namespace hlsl
208264
} // namespace llvm

llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,67 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause(
227227
});
228228
}
229229

230+
std::optional<const RangeInfo *>
231+
ResourceRange::getOverlapping(const RangeInfo &Info) const {
232+
IMap::const_iterator Interval = Intervals.find(Info.LowerBound);
233+
if (!Interval.valid() || Info.UpperBound < Interval.start())
234+
return std::nullopt;
235+
return Interval.value();
236+
}
237+
238+
const RangeInfo *ResourceRange::lookup(uint32_t X) const {
239+
return Intervals.lookup(X, nullptr);
240+
}
241+
242+
std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
243+
uint32_t LowerBound = Info.LowerBound;
244+
uint32_t UpperBound = Info.UpperBound;
245+
246+
std::optional<const RangeInfo *> Res = std::nullopt;
247+
IMap::iterator Interval = Intervals.begin();
248+
249+
while (true) {
250+
if (UpperBound < LowerBound)
251+
break;
252+
253+
Interval.advanceTo(LowerBound);
254+
if (!Interval.valid()) // No interval found
255+
break;
256+
257+
// Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
258+
// a <= y implicitly from Intervals.find(LowerBound)
259+
if (UpperBound < Interval.start())
260+
break; // found interval does not overlap with inserted one
261+
262+
if (!Res.has_value()) // Update to be the first found intersection
263+
Res = Interval.value();
264+
265+
if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
266+
// x <= a <= b <= y implies that [a;b] is covered by [x;y]
267+
// -> so we don't need to insert this, report an overlap
268+
return Res;
269+
} else if (LowerBound <= Interval.start() &&
270+
Interval.stop() <= UpperBound) {
271+
// a <= x <= y <= b implies that [x;y] is covered by [a;b]
272+
// -> so remove the existing interval that we will cover with the
273+
// overwrite
274+
Interval.erase();
275+
} else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
276+
// a < x <= b <= y implies that [a; x] is not covered but [x;b] is
277+
// -> so set b = x - 1 such that [a;x-1] is now the interval to insert
278+
UpperBound = Interval.start() - 1;
279+
} else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
280+
// a < x <= b <= y implies that [y; b] is not covered but [a;y] is
281+
// -> so set a = y + 1 such that [y+1;b] is now the interval to insert
282+
LowerBound = Interval.stop() + 1;
283+
}
284+
}
285+
286+
assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
287+
Intervals.insert(LowerBound, UpperBound, &Info);
288+
return Res;
289+
}
290+
230291
} // namespace rootsig
231292
} // namespace hlsl
232293
} // namespace llvm

llvm/unittests/Frontend/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ set(LLVM_LINK_COMPONENTS
1212

1313
add_llvm_unittest(LLVMFrontendTests
1414
HLSLRootSignatureDumpTest.cpp
15+
HLSLRootSignatureRangesTest.cpp
1516
OpenACCTest.cpp
1617
OpenMPContextTest.cpp
1718
OpenMPIRBuilderTest.cpp
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
//===------ HLSLRootSignatureRangeTest.cpp - RootSignature Range tests ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace llvm::hlsl::rootsig;
13+
14+
namespace {
15+
16+
TEST(HLSLRootSignatureTest, NoOverlappingInsertTests) {
17+
// Ensures that there is never a reported overlap
18+
ResourceRange::IMap::Allocator Allocator;
19+
ResourceRange Range(Allocator);
20+
21+
RangeInfo A;
22+
A.LowerBound = 0;
23+
A.UpperBound = 3;
24+
EXPECT_EQ(Range.insert(A), std::nullopt);
25+
26+
RangeInfo B;
27+
B.LowerBound = 4;
28+
B.UpperBound = 7;
29+
EXPECT_EQ(Range.insert(B), std::nullopt);
30+
31+
RangeInfo C;
32+
C.LowerBound = 10;
33+
C.UpperBound = RangeInfo::Unbounded;
34+
EXPECT_EQ(Range.insert(C), std::nullopt);
35+
36+
// A = [0;3]
37+
EXPECT_EQ(Range.lookup(0), &A);
38+
EXPECT_EQ(Range.lookup(2), &A);
39+
EXPECT_EQ(Range.lookup(3), &A);
40+
41+
// B = [4;7]
42+
EXPECT_EQ(Range.lookup(4), &B);
43+
EXPECT_EQ(Range.lookup(5), &B);
44+
EXPECT_EQ(Range.lookup(7), &B);
45+
46+
EXPECT_EQ(Range.lookup(8), nullptr);
47+
EXPECT_EQ(Range.lookup(9), nullptr);
48+
49+
// C = [10;unbounded]
50+
EXPECT_EQ(Range.lookup(10), &C);
51+
EXPECT_EQ(Range.lookup(42), &C);
52+
EXPECT_EQ(Range.lookup(98237423), &C);
53+
EXPECT_EQ(Range.lookup(RangeInfo::Unbounded), &C);
54+
}
55+
56+
TEST(HLSLRootSignatureTest, SingleOverlappingInsertTests) {
57+
// Ensures that we correctly report an overlap when we insert a range that
58+
// overlaps with one other range but does not cover (replace) it
59+
ResourceRange::IMap::Allocator Allocator;
60+
ResourceRange Range(Allocator);
61+
62+
RangeInfo A;
63+
A.LowerBound = 1;
64+
A.UpperBound = 5;
65+
EXPECT_EQ(Range.insert(A), std::nullopt);
66+
67+
RangeInfo B;
68+
B.LowerBound = 0;
69+
B.UpperBound = 2;
70+
EXPECT_EQ(Range.insert(B).value(), &A);
71+
72+
RangeInfo C;
73+
C.LowerBound = 4;
74+
C.UpperBound = RangeInfo::Unbounded;
75+
EXPECT_EQ(Range.insert(C).value(), &A);
76+
77+
// A = [1;5]
78+
EXPECT_EQ(Range.lookup(1), &A);
79+
EXPECT_EQ(Range.lookup(2), &A);
80+
EXPECT_EQ(Range.lookup(3), &A);
81+
EXPECT_EQ(Range.lookup(4), &A);
82+
EXPECT_EQ(Range.lookup(5), &A);
83+
84+
// B = [0;0]
85+
EXPECT_EQ(Range.lookup(0), &B);
86+
87+
// C = [6; unbounded]
88+
EXPECT_EQ(Range.lookup(6), &C);
89+
EXPECT_EQ(Range.lookup(RangeInfo::Unbounded), &C);
90+
}
91+
92+
TEST(HLSLRootSignatureTest, MultipleOverlappingInsertTests) {
93+
// Ensures that we correctly report an overlap when inserted range
94+
// overlaps more than one range and it does not cover (replace) either
95+
// range. In this case it will just fill in the interval between the two
96+
ResourceRange::IMap::Allocator Allocator;
97+
ResourceRange Range(Allocator);
98+
99+
RangeInfo A;
100+
A.LowerBound = 0;
101+
A.UpperBound = 2;
102+
EXPECT_EQ(Range.insert(A), std::nullopt);
103+
104+
RangeInfo B;
105+
B.LowerBound = 4;
106+
B.UpperBound = 6;
107+
EXPECT_EQ(Range.insert(B), std::nullopt);
108+
109+
RangeInfo C;
110+
C.LowerBound = 1;
111+
C.UpperBound = 5;
112+
EXPECT_EQ(Range.insert(C).value(), &A);
113+
114+
// A = [0;2]
115+
EXPECT_EQ(Range.lookup(0), &A);
116+
EXPECT_EQ(Range.lookup(1), &A);
117+
EXPECT_EQ(Range.lookup(2), &A);
118+
119+
// B = [4;6]
120+
EXPECT_EQ(Range.lookup(4), &B);
121+
EXPECT_EQ(Range.lookup(5), &B);
122+
EXPECT_EQ(Range.lookup(6), &B);
123+
124+
// C = [3;3]
125+
EXPECT_EQ(Range.lookup(3), &C);
126+
}
127+
128+
TEST(HLSLRootSignatureTest, CoverInsertTests) {
129+
// Ensures that we correctly report an overlap when inserted range
130+
// covers one or more ranges
131+
ResourceRange::IMap::Allocator Allocator;
132+
ResourceRange Range(Allocator);
133+
134+
RangeInfo A;
135+
A.LowerBound = 0;
136+
A.UpperBound = 2;
137+
EXPECT_EQ(Range.insert(A), std::nullopt);
138+
139+
RangeInfo B;
140+
B.LowerBound = 4;
141+
B.UpperBound = 5;
142+
EXPECT_EQ(Range.insert(B), std::nullopt);
143+
144+
// Covers B
145+
RangeInfo C;
146+
C.LowerBound = 4;
147+
C.UpperBound = 6;
148+
EXPECT_EQ(Range.insert(C).value(), &B);
149+
150+
// A = [0;2]
151+
// C = [4;6] <- covers reference to B
152+
EXPECT_EQ(Range.lookup(0), &A);
153+
EXPECT_EQ(Range.lookup(1), &A);
154+
EXPECT_EQ(Range.lookup(2), &A);
155+
EXPECT_EQ(Range.lookup(3), nullptr);
156+
EXPECT_EQ(Range.lookup(4), &C);
157+
EXPECT_EQ(Range.lookup(5), &C);
158+
EXPECT_EQ(Range.lookup(6), &C);
159+
160+
// Covers all other ranges
161+
RangeInfo D;
162+
D.LowerBound = 0;
163+
D.UpperBound = 7;
164+
EXPECT_EQ(Range.insert(D).value(), &A);
165+
166+
// D = [0;7] <- Covers reference to A and C
167+
EXPECT_EQ(Range.lookup(0), &D);
168+
EXPECT_EQ(Range.lookup(1), &D);
169+
EXPECT_EQ(Range.lookup(2), &D);
170+
EXPECT_EQ(Range.lookup(3), &D);
171+
EXPECT_EQ(Range.lookup(4), &D);
172+
EXPECT_EQ(Range.lookup(5), &D);
173+
EXPECT_EQ(Range.lookup(6), &D);
174+
EXPECT_EQ(Range.lookup(7), &D);
175+
}
176+
177+
} // namespace

0 commit comments

Comments
 (0)