Skip to content

Commit a3240de

Browse files
committed
[HLSL][RootSignature] Implement resource register validation
1 parent ab39042 commit a3240de

File tree

4 files changed

+295
-0
lines changed

4 files changed

+295
-0
lines changed

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
1616

1717
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/IntervalMap.h"
1819
#include "llvm/Support/DXILABI.h"
1920
#include "llvm/Support/raw_ostream.h"
2021
#include <variant>
@@ -198,6 +199,61 @@ class MetadataBuilder {
198199
SmallVector<Metadata *> GeneratedMetadata;
199200
};
200201

202+
// RangeInfo holds the information to correctly construct a ResourceRange
203+
// and retains this information to be used for displaying a better diagnostic
204+
struct RangeInfo {
205+
const static uint32_t Unbounded = static_cast<uint32_t>(-1);
206+
207+
uint32_t LowerBound;
208+
uint32_t UpperBound;
209+
};
210+
211+
class ResourceRange {
212+
public:
213+
using IMap = llvm::IntervalMap<uint32_t, const RangeInfo *, 16,
214+
llvm::IntervalMapInfo<uint32_t>>;
215+
216+
private:
217+
IMap Intervals;
218+
219+
public:
220+
ResourceRange(IMap::Allocator &Allocator) : Intervals(IMap(Allocator)) {}
221+
222+
// Returns a reference to the first RangeInfo that overlaps with
223+
// [Info.LowerBound;Info.UpperBound], or, std::nullopt if there is no overlap
224+
std::optional<const RangeInfo *> getOverlapping(const RangeInfo &Info) const;
225+
226+
// Return the mapped RangeInfo at X or nullptr if no mapping exists
227+
const RangeInfo *lookup(uint32_t X) const;
228+
229+
// Insert the required (sub-)intervals such that the interval of [a;b] =
230+
// [Info.LowerBound, Info.UpperBound] is covered and points to a valid
231+
// RangeInfo &.
232+
//
233+
// For instance consider the following chain of inserting RangeInfos with the
234+
// intervals denoting the Lower/Upper-bounds:
235+
//
236+
// A = [0;2]
237+
// insert(A) -> false
238+
// intervals: [0;2] -> &A
239+
// B = [5;7]
240+
// insert(B) -> false
241+
// intervals: [0;2] -> &A, [5;7] -> &B
242+
// C = [4;7]
243+
// insert(C) -> true
244+
// intervals: [0;2] -> &A, [4;7] -> &C
245+
// D = [1;5]
246+
// insert(D) -> true
247+
// intervals: [0;2] -> &A, [3;3] -> &D, [4;7] -> &C
248+
// E = [0;unbounded]
249+
// insert(E) -> true
250+
// intervals: [0;unbounded] -> E
251+
//
252+
// Returns if the first overlapping range when inserting
253+
// (same return as getOverlapping)
254+
std::optional<const RangeInfo *> insert(const RangeInfo &Info);
255+
};
256+
201257
} // namespace rootsig
202258
} // namespace hlsl
203259
} // namespace llvm

llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,67 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause(
222222
});
223223
}
224224

225+
std::optional<const RangeInfo *>
226+
ResourceRange::getOverlapping(const RangeInfo &Info) const {
227+
IMap::const_iterator Interval = Intervals.find(Info.LowerBound);
228+
if (!Interval.valid() || Info.UpperBound < Interval.start())
229+
return std::nullopt;
230+
return Interval.value();
231+
}
232+
233+
const RangeInfo *ResourceRange::lookup(uint32_t X) const {
234+
return Intervals.lookup(X, nullptr);
235+
}
236+
237+
std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
238+
uint32_t LowerBound = Info.LowerBound;
239+
uint32_t UpperBound = Info.UpperBound;
240+
241+
std::optional<const RangeInfo *> Res = std::nullopt;
242+
IMap::iterator Interval = Intervals.begin();
243+
244+
while (true) {
245+
if (UpperBound < LowerBound)
246+
break;
247+
248+
Interval.advanceTo(LowerBound);
249+
if (!Interval.valid()) // No interval found
250+
break;
251+
252+
// Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
253+
// a <= y implicitly from Intervals.find(LowerBound)
254+
if (UpperBound < Interval.start())
255+
break; // found interval does not overlap with inserted one
256+
257+
if (!Res.has_value()) // Update to be the first found intersection
258+
Res = Interval.value();
259+
260+
if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
261+
// x <= a <= b <= y implies that [a;b] is covered by [x;y]
262+
// -> so we don't need to insert this, report an overlap
263+
return Res;
264+
} else if (LowerBound <= Interval.start() &&
265+
Interval.stop() <= UpperBound) {
266+
// a <= x <= y <= b implies that [x;y] is covered by [a;b]
267+
// -> so remove the existing interval that we will cover with the
268+
// overwrite
269+
Interval.erase();
270+
} else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
271+
// a < x <= b <= y implies that [a; x] is not covered but [x;b] is
272+
// -> so set b = x - 1 such that [a;x-1] is now the interval to insert
273+
UpperBound = Interval.start() - 1;
274+
} else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
275+
// a < x <= b <= y implies that [y; b] is not covered but [a;y] is
276+
// -> so set a = y + 1 such that [y+1;b] is now the interval to insert
277+
LowerBound = Interval.stop() + 1;
278+
}
279+
}
280+
281+
assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
282+
Intervals.insert(LowerBound, UpperBound, &Info);
283+
return Res;
284+
}
285+
225286
} // namespace rootsig
226287
} // namespace hlsl
227288
} // namespace llvm

llvm/unittests/Frontend/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ set(LLVM_LINK_COMPONENTS
1212

1313
add_llvm_unittest(LLVMFrontendTests
1414
HLSLRootSignatureDumpTest.cpp
15+
HLSLRootSignatureRangesTest.cpp
1516
OpenACCTest.cpp
1617
OpenMPContextTest.cpp
1718
OpenMPIRBuilderTest.cpp
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
//===------ HLSLRootSignatureRangeTest.cpp - RootSignature Range tests ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace llvm::hlsl::rootsig;
13+
14+
namespace {
15+
16+
TEST(HLSLRootSignatureTest, NoOverlappingInsertTests) {
17+
// Ensures that there is never a reported overlap
18+
ResourceRange::IMap::Allocator Allocator;
19+
ResourceRange Range(Allocator);
20+
21+
RangeInfo A;
22+
A.LowerBound = 0;
23+
A.UpperBound = 3;
24+
EXPECT_EQ(Range.insert(A), std::nullopt);
25+
26+
RangeInfo B;
27+
B.LowerBound = 4;
28+
B.UpperBound = 7;
29+
EXPECT_EQ(Range.insert(B), std::nullopt);
30+
31+
RangeInfo C;
32+
C.LowerBound = 10;
33+
C.UpperBound = RangeInfo::Unbounded;
34+
EXPECT_EQ(Range.insert(C), std::nullopt);
35+
36+
// A = [0;3]
37+
EXPECT_EQ(Range.lookup(0), &A);
38+
EXPECT_EQ(Range.lookup(2), &A);
39+
EXPECT_EQ(Range.lookup(3), &A);
40+
41+
// B = [4;7]
42+
EXPECT_EQ(Range.lookup(4), &B);
43+
EXPECT_EQ(Range.lookup(5), &B);
44+
EXPECT_EQ(Range.lookup(7), &B);
45+
46+
EXPECT_EQ(Range.lookup(8), nullptr);
47+
EXPECT_EQ(Range.lookup(9), nullptr);
48+
49+
// C = [10;unbounded]
50+
EXPECT_EQ(Range.lookup(10), &C);
51+
EXPECT_EQ(Range.lookup(42), &C);
52+
EXPECT_EQ(Range.lookup(98237423), &C);
53+
EXPECT_EQ(Range.lookup(RangeInfo::Unbounded), &C);
54+
}
55+
56+
TEST(HLSLRootSignatureTest, SingleOverlappingInsertTests) {
57+
// Ensures that we correctly report an overlap when we insert a range that
58+
// overlaps with one other range but does not cover (replace) it
59+
ResourceRange::IMap::Allocator Allocator;
60+
ResourceRange Range(Allocator);
61+
62+
RangeInfo A;
63+
A.LowerBound = 1;
64+
A.UpperBound = 5;
65+
EXPECT_EQ(Range.insert(A), std::nullopt);
66+
67+
RangeInfo B;
68+
B.LowerBound = 0;
69+
B.UpperBound = 2;
70+
EXPECT_EQ(Range.insert(B).value(), &A);
71+
72+
RangeInfo C;
73+
C.LowerBound = 4;
74+
C.UpperBound = RangeInfo::Unbounded;
75+
EXPECT_EQ(Range.insert(C).value(), &A);
76+
77+
// A = [1;5]
78+
EXPECT_EQ(Range.lookup(1), &A);
79+
EXPECT_EQ(Range.lookup(2), &A);
80+
EXPECT_EQ(Range.lookup(3), &A);
81+
EXPECT_EQ(Range.lookup(4), &A);
82+
EXPECT_EQ(Range.lookup(5), &A);
83+
84+
// B = [0;0]
85+
EXPECT_EQ(Range.lookup(0), &B);
86+
87+
// C = [6; unbounded]
88+
EXPECT_EQ(Range.lookup(6), &C);
89+
EXPECT_EQ(Range.lookup(RangeInfo::Unbounded), &C);
90+
}
91+
92+
TEST(HLSLRootSignatureTest, MultipleOverlappingInsertTests) {
93+
// Ensures that we correctly report an overlap when inserted range
94+
// overlaps more than one range and it does not cover (replace) either
95+
// range. In this case it will just fill in the interval between the two
96+
ResourceRange::IMap::Allocator Allocator;
97+
ResourceRange Range(Allocator);
98+
99+
RangeInfo A;
100+
A.LowerBound = 0;
101+
A.UpperBound = 2;
102+
EXPECT_EQ(Range.insert(A), std::nullopt);
103+
104+
RangeInfo B;
105+
B.LowerBound = 4;
106+
B.UpperBound = 6;
107+
EXPECT_EQ(Range.insert(B), std::nullopt);
108+
109+
RangeInfo C;
110+
C.LowerBound = 1;
111+
C.UpperBound = 5;
112+
EXPECT_EQ(Range.insert(C).value(), &A);
113+
114+
// A = [0;2]
115+
EXPECT_EQ(Range.lookup(0), &A);
116+
EXPECT_EQ(Range.lookup(1), &A);
117+
EXPECT_EQ(Range.lookup(2), &A);
118+
119+
// B = [4;6]
120+
EXPECT_EQ(Range.lookup(4), &B);
121+
EXPECT_EQ(Range.lookup(5), &B);
122+
EXPECT_EQ(Range.lookup(6), &B);
123+
124+
// C = [3;3]
125+
EXPECT_EQ(Range.lookup(3), &C);
126+
}
127+
128+
TEST(HLSLRootSignatureTest, CoverInsertTests) {
129+
// Ensures that we correctly report an overlap when inserted range
130+
// covers one or more ranges
131+
ResourceRange::IMap::Allocator Allocator;
132+
ResourceRange Range(Allocator);
133+
134+
RangeInfo A;
135+
A.LowerBound = 0;
136+
A.UpperBound = 2;
137+
EXPECT_EQ(Range.insert(A), std::nullopt);
138+
139+
RangeInfo B;
140+
B.LowerBound = 4;
141+
B.UpperBound = 5;
142+
EXPECT_EQ(Range.insert(B), std::nullopt);
143+
144+
// Covers B
145+
RangeInfo C;
146+
C.LowerBound = 4;
147+
C.UpperBound = 6;
148+
EXPECT_EQ(Range.insert(C).value(), &B);
149+
150+
// A = [0;2]
151+
// C = [4;6] <- covers reference to B
152+
EXPECT_EQ(Range.lookup(0), &A);
153+
EXPECT_EQ(Range.lookup(1), &A);
154+
EXPECT_EQ(Range.lookup(2), &A);
155+
EXPECT_EQ(Range.lookup(3), nullptr);
156+
EXPECT_EQ(Range.lookup(4), &C);
157+
EXPECT_EQ(Range.lookup(5), &C);
158+
EXPECT_EQ(Range.lookup(6), &C);
159+
160+
// Covers all other ranges
161+
RangeInfo D;
162+
D.LowerBound = 0;
163+
D.UpperBound = 7;
164+
EXPECT_EQ(Range.insert(D).value(), &A);
165+
166+
// D = [0;7] <- Covers reference to A and C
167+
EXPECT_EQ(Range.lookup(0), &D);
168+
EXPECT_EQ(Range.lookup(1), &D);
169+
EXPECT_EQ(Range.lookup(2), &D);
170+
EXPECT_EQ(Range.lookup(3), &D);
171+
EXPECT_EQ(Range.lookup(4), &D);
172+
EXPECT_EQ(Range.lookup(5), &D);
173+
EXPECT_EQ(Range.lookup(6), &D);
174+
EXPECT_EQ(Range.lookup(7), &D);
175+
}
176+
177+
} // namespace

0 commit comments

Comments
 (0)