diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index 2f028817b45b6..9dfbd3cb68928 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -16,6 +16,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/DXILABI.h" +#include #include namespace llvm { diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h index ca20e6719f3a4..4d2cd183ebcbc 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignatureUtils.h @@ -15,6 +15,7 @@ #define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATUREUTILS_H #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/IntervalMap.h" #include "llvm/Frontend/HLSL/HLSLRootSignature.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/raw_ostream.h" @@ -64,6 +65,62 @@ class MetadataBuilder { SmallVector GeneratedMetadata; }; +// RangeInfo holds the information to correctly construct a ResourceRange +// and retains this information to be used for displaying a better diagnostic +struct RangeInfo { + const static uint32_t Unbounded = ~0u; + + uint32_t LowerBound; + uint32_t UpperBound; +}; + +class ResourceRange { +public: + using MapT = llvm::IntervalMap>; + +private: + MapT Intervals; + +public: + ResourceRange(MapT::Allocator &Allocator) : Intervals(MapT(Allocator)) {} + + // Returns a reference to the first RangeInfo that overlaps with + // [Info.LowerBound;Info.UpperBound], or, std::nullopt if there is no overlap + std::optional getOverlapping(const RangeInfo &Info) const; + + // Return the mapped RangeInfo at X or nullptr if no mapping exists + const RangeInfo *lookup(uint32_t X) const; + + // Insert the required (sub-)intervals such that the interval of [a;b] = + // [Info.LowerBound, Info.UpperBound] is covered and points to a valid + // RangeInfo &. + // + // For instance consider the following chain of inserting RangeInfos with the + // intervals denoting the Lower/Upper-bounds: + // + // A = [0;2] + // insert(A) -> false + // intervals: [0;2] -> &A + // B = [5;7] + // insert(B) -> false + // intervals: [0;2] -> &A, [5;7] -> &B + // C = [4;7] + // insert(C) -> true + // intervals: [0;2] -> &A, [4;7] -> &C + // D = [1;5] + // insert(D) -> true + // intervals: [0;2] -> &A, [3;3] -> &D, [4;7] -> &C + // E = [0;unbounded] + // insert(E) -> true + // intervals: [0;unbounded] -> E + // + // Returns a reference to the first RangeInfo that overlaps with + // [Info.LowerBound;Info.UpperBound], or, std::nullopt if there is no overlap + // (equivalent to getOverlapping) + std::optional insert(const RangeInfo &Info); +}; + } // namespace rootsig } // namespace hlsl } // namespace llvm diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index 5bae72a3986f8..1e198b639cfdc 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -355,6 +355,67 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { return MDNode::get(Ctx, Operands); } +std::optional +ResourceRange::getOverlapping(const RangeInfo &Info) const { + MapT::const_iterator Interval = Intervals.find(Info.LowerBound); + if (!Interval.valid() || Info.UpperBound < Interval.start()) + return std::nullopt; + return Interval.value(); +} + +const RangeInfo *ResourceRange::lookup(uint32_t X) const { + return Intervals.lookup(X, nullptr); +} + +std::optional ResourceRange::insert(const RangeInfo &Info) { + uint32_t LowerBound = Info.LowerBound; + uint32_t UpperBound = Info.UpperBound; + + std::optional Res = std::nullopt; + MapT::iterator Interval = Intervals.begin(); + + while (true) { + if (UpperBound < LowerBound) + break; + + Interval.advanceTo(LowerBound); + if (!Interval.valid()) // No interval found + break; + + // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that + // a <= y implicitly from Intervals.find(LowerBound) + if (UpperBound < Interval.start()) + break; // found interval does not overlap with inserted one + + if (!Res.has_value()) // Update to be the first found intersection + Res = Interval.value(); + + if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) { + // x <= a <= b <= y implies that [a;b] is covered by [x;y] + // -> so we don't need to insert this, report an overlap + return Res; + } else if (LowerBound <= Interval.start() && + Interval.stop() <= UpperBound) { + // a <= x <= y <= b implies that [x;y] is covered by [a;b] + // -> so remove the existing interval that we will cover with the + // overwrite + Interval.erase(); + } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) { + // a < x <= b <= y implies that [a; x] is not covered but [x;b] is + // -> so set b = x - 1 such that [a;x-1] is now the interval to insert + UpperBound = Interval.start() - 1; + } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) { + // a < x <= b <= y implies that [y; b] is not covered but [a;y] is + // -> so set a = y + 1 such that [y+1;b] is now the interval to insert + LowerBound = Interval.stop() + 1; + } + } + + assert(LowerBound <= UpperBound && "Attempting to insert an empty interval"); + Intervals.insert(LowerBound, UpperBound, &Info); + return Res; +} + } // namespace rootsig } // namespace hlsl } // namespace llvm diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt index 2119642769e3d..4048143b36819 100644 --- a/llvm/unittests/Frontend/CMakeLists.txt +++ b/llvm/unittests/Frontend/CMakeLists.txt @@ -12,6 +12,7 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(LLVMFrontendTests HLSLRootSignatureDumpTest.cpp + HLSLRootSignatureRangesTest.cpp OpenACCTest.cpp OpenMPContextTest.cpp OpenMPIRBuilderTest.cpp diff --git a/llvm/unittests/Frontend/HLSLRootSignatureRangesTest.cpp b/llvm/unittests/Frontend/HLSLRootSignatureRangesTest.cpp new file mode 100644 index 0000000000000..0ef6fe84f0ec9 --- /dev/null +++ b/llvm/unittests/Frontend/HLSLRootSignatureRangesTest.cpp @@ -0,0 +1,177 @@ +//===------ HLSLRootSignatureRangeTest.cpp - RootSignature Range tests ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h" +#include "gtest/gtest.h" + +using namespace llvm::hlsl::rootsig; + +namespace { + +TEST(HLSLRootSignatureTest, NoOverlappingInsertTests) { + // Ensures that there is never a reported overlap + ResourceRange::MapT::Allocator Allocator; + ResourceRange Range(Allocator); + + RangeInfo A; + A.LowerBound = 0; + A.UpperBound = 3; + EXPECT_EQ(Range.insert(A), std::nullopt); + + RangeInfo B; + B.LowerBound = 4; + B.UpperBound = 7; + EXPECT_EQ(Range.insert(B), std::nullopt); + + RangeInfo C; + C.LowerBound = 10; + C.UpperBound = RangeInfo::Unbounded; + EXPECT_EQ(Range.insert(C), std::nullopt); + + // A = [0;3] + EXPECT_EQ(Range.lookup(0), &A); + EXPECT_EQ(Range.lookup(2), &A); + EXPECT_EQ(Range.lookup(3), &A); + + // B = [4;7] + EXPECT_EQ(Range.lookup(4), &B); + EXPECT_EQ(Range.lookup(5), &B); + EXPECT_EQ(Range.lookup(7), &B); + + EXPECT_EQ(Range.lookup(8), nullptr); + EXPECT_EQ(Range.lookup(9), nullptr); + + // C = [10;unbounded] + EXPECT_EQ(Range.lookup(10), &C); + EXPECT_EQ(Range.lookup(42), &C); + EXPECT_EQ(Range.lookup(98237423), &C); + EXPECT_EQ(Range.lookup(RangeInfo::Unbounded), &C); +} + +TEST(HLSLRootSignatureTest, SingleOverlappingInsertTests) { + // Ensures that we correctly report an overlap when we insert a range that + // overlaps with one other range but does not cover (replace) it + ResourceRange::MapT::Allocator Allocator; + ResourceRange Range(Allocator); + + RangeInfo A; + A.LowerBound = 1; + A.UpperBound = 5; + EXPECT_EQ(Range.insert(A), std::nullopt); + + RangeInfo B; + B.LowerBound = 0; + B.UpperBound = 2; + EXPECT_EQ(Range.insert(B).value(), &A); + + RangeInfo C; + C.LowerBound = 4; + C.UpperBound = RangeInfo::Unbounded; + EXPECT_EQ(Range.insert(C).value(), &A); + + // A = [1;5] + EXPECT_EQ(Range.lookup(1), &A); + EXPECT_EQ(Range.lookup(2), &A); + EXPECT_EQ(Range.lookup(3), &A); + EXPECT_EQ(Range.lookup(4), &A); + EXPECT_EQ(Range.lookup(5), &A); + + // B = [0;0] + EXPECT_EQ(Range.lookup(0), &B); + + // C = [6; unbounded] + EXPECT_EQ(Range.lookup(6), &C); + EXPECT_EQ(Range.lookup(RangeInfo::Unbounded), &C); +} + +TEST(HLSLRootSignatureTest, MultipleOverlappingInsertTests) { + // Ensures that we correctly report an overlap when inserted range + // overlaps more than one range and it does not cover (replace) either + // range. In this case it will just fill in the interval between the two + ResourceRange::MapT::Allocator Allocator; + ResourceRange Range(Allocator); + + RangeInfo A; + A.LowerBound = 0; + A.UpperBound = 2; + EXPECT_EQ(Range.insert(A), std::nullopt); + + RangeInfo B; + B.LowerBound = 4; + B.UpperBound = 6; + EXPECT_EQ(Range.insert(B), std::nullopt); + + RangeInfo C; + C.LowerBound = 1; + C.UpperBound = 5; + EXPECT_EQ(Range.insert(C).value(), &A); + + // A = [0;2] + EXPECT_EQ(Range.lookup(0), &A); + EXPECT_EQ(Range.lookup(1), &A); + EXPECT_EQ(Range.lookup(2), &A); + + // B = [4;6] + EXPECT_EQ(Range.lookup(4), &B); + EXPECT_EQ(Range.lookup(5), &B); + EXPECT_EQ(Range.lookup(6), &B); + + // C = [3;3] + EXPECT_EQ(Range.lookup(3), &C); +} + +TEST(HLSLRootSignatureTest, CoverInsertTests) { + // Ensures that we correctly report an overlap when inserted range + // covers one or more ranges + ResourceRange::MapT::Allocator Allocator; + ResourceRange Range(Allocator); + + RangeInfo A; + A.LowerBound = 0; + A.UpperBound = 2; + EXPECT_EQ(Range.insert(A), std::nullopt); + + RangeInfo B; + B.LowerBound = 4; + B.UpperBound = 5; + EXPECT_EQ(Range.insert(B), std::nullopt); + + // Covers B + RangeInfo C; + C.LowerBound = 4; + C.UpperBound = 6; + EXPECT_EQ(Range.insert(C).value(), &B); + + // A = [0;2] + // C = [4;6] <- covers reference to B + EXPECT_EQ(Range.lookup(0), &A); + EXPECT_EQ(Range.lookup(1), &A); + EXPECT_EQ(Range.lookup(2), &A); + EXPECT_EQ(Range.lookup(3), nullptr); + EXPECT_EQ(Range.lookup(4), &C); + EXPECT_EQ(Range.lookup(5), &C); + EXPECT_EQ(Range.lookup(6), &C); + + // Covers all other ranges + RangeInfo D; + D.LowerBound = 0; + D.UpperBound = 7; + EXPECT_EQ(Range.insert(D).value(), &A); + + // D = [0;7] <- Covers reference to A and C + EXPECT_EQ(Range.lookup(0), &D); + EXPECT_EQ(Range.lookup(1), &D); + EXPECT_EQ(Range.lookup(2), &D); + EXPECT_EQ(Range.lookup(3), &D); + EXPECT_EQ(Range.lookup(4), &D); + EXPECT_EQ(Range.lookup(5), &D); + EXPECT_EQ(Range.lookup(6), &D); + EXPECT_EQ(Range.lookup(7), &D); +} + +} // namespace