Skip to content

Commit ad0e1f3

Browse files
Joao SaffranJoao Saffran
authored andcommitted
address comments
1 parent 81261ff commit ad0e1f3

File tree

3 files changed

+177
-124
lines changed

3 files changed

+177
-124
lines changed

llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H
1616

1717
#include "llvm/ADT/StringRef.h"
18+
#include "llvm/BinaryFormat/DXContainer.h"
1819
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
1920
#include "llvm/IR/Constants.h"
2021
#include "llvm/MC/DXContainerRootSignature.h"
22+
#include <cstdint>
2123

2224
namespace llvm {
2325
class LLVMContext;
@@ -27,6 +29,38 @@ class Metadata;
2729
namespace hlsl {
2830
namespace rootsig {
2931

32+
inline dxil::ResourceClass
33+
toResourceClass(dxbc::DescriptorRangeType RangeType) {
34+
using namespace dxbc;
35+
switch (RangeType) {
36+
case DescriptorRangeType::SRV:
37+
return dxil::ResourceClass::SRV;
38+
case DescriptorRangeType::UAV:
39+
return dxil::ResourceClass::UAV;
40+
case DescriptorRangeType::CBV:
41+
return dxil::ResourceClass::CBuffer;
42+
case DescriptorRangeType::Sampler:
43+
return dxil::ResourceClass::Sampler;
44+
}
45+
}
46+
47+
inline dxil::ResourceClass toResourceClass(dxbc::RootParameterType Type) {
48+
using namespace dxbc;
49+
switch (Type) {
50+
case RootParameterType::Constants32Bit:
51+
return dxil::ResourceClass::CBuffer;
52+
case RootParameterType::SRV:
53+
return dxil::ResourceClass::SRV;
54+
case RootParameterType::UAV:
55+
return dxil::ResourceClass::UAV;
56+
case RootParameterType::CBV:
57+
return dxil::ResourceClass::CBuffer;
58+
case dxbc::RootParameterType::DescriptorTable:
59+
break;
60+
}
61+
llvm_unreachable("Unconvertible RootParameterType");
62+
}
63+
3064
template <typename T>
3165
class RootSignatureValidationError
3266
: public ErrorInfo<RootSignatureValidationError<T>> {
@@ -47,6 +81,51 @@ class RootSignatureValidationError
4781
}
4882
};
4983

84+
class TableRegisterOverflowError
85+
: public ErrorInfo<TableRegisterOverflowError> {
86+
public:
87+
static char ID;
88+
dxbc::DescriptorRangeType Type;
89+
uint32_t Register;
90+
uint32_t Space;
91+
92+
TableRegisterOverflowError(dxbc::DescriptorRangeType Type, uint32_t Register,
93+
uint32_t Space)
94+
: Type(Type), Register(Register), Space(Space) {}
95+
96+
void log(raw_ostream &OS) const override {
97+
OS << "Cannot bind resource of type "
98+
<< getResourceClassName(toResourceClass(Type))
99+
<< "(register=" << Register << ", space=" << Space
100+
<< "), it exceeds the maximum allowed register value.";
101+
}
102+
103+
std::error_code convertToErrorCode() const override {
104+
return llvm::inconvertibleErrorCode();
105+
}
106+
};
107+
108+
class TableSamplerMixinError : public ErrorInfo<TableSamplerMixinError> {
109+
public:
110+
static char ID;
111+
dxbc::DescriptorRangeType Type;
112+
uint32_t Location;
113+
114+
TableSamplerMixinError(dxbc::DescriptorRangeType Type, uint32_t Location)
115+
: Type(Type), Location(Location) {}
116+
117+
void log(raw_ostream &OS) const override {
118+
OS << "Samplers cannot be mixed with other "
119+
<< "resource types in a descriptor table, "
120+
<< getResourceClassName(toResourceClass(Type))
121+
<< "(location=" << Location << ")";
122+
}
123+
124+
std::error_code convertToErrorCode() const override {
125+
return llvm::inconvertibleErrorCode();
126+
}
127+
};
128+
50129
class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
51130
public:
52131
static char ID;

llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
1616
#include "llvm/IR/IRBuilder.h"
1717
#include "llvm/IR/Metadata.h"
18+
#include "llvm/Support/Error.h"
1819
#include "llvm/Support/ScopedPrinter.h"
20+
#include <cstdint>
1921

2022
using namespace llvm;
2123

@@ -26,6 +28,8 @@ namespace rootsig {
2628
char GenericRSMetadataError::ID;
2729
char InvalidRSMetadataFormat::ID;
2830
char InvalidRSMetadataValue::ID;
31+
char TableSamplerMixinError::ID;
32+
char TableRegisterOverflowError::ID;
2933
template <typename T> char RootSignatureValidationError<T>::ID;
3034

3135
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
@@ -525,6 +529,83 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
525529
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
526530
}
527531

532+
Error validateDescriptorTableSamplerMixin(mcdxbc::DescriptorTable Table,
533+
uint32_t Location) {
534+
bool HasSampler = false;
535+
bool HasOtherRangeType = false;
536+
dxbc::DescriptorRangeType OtherRangeType;
537+
538+
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
539+
dxbc::DescriptorRangeType RangeType =
540+
static_cast<dxbc::DescriptorRangeType>(Range.RangeType);
541+
542+
if (RangeType == dxbc::DescriptorRangeType::Sampler) {
543+
HasSampler = true;
544+
} else {
545+
HasOtherRangeType = true;
546+
OtherRangeType = RangeType;
547+
}
548+
}
549+
550+
// Samplers cannot be mixed with other resources in a descriptor table.
551+
if (HasSampler && HasOtherRangeType)
552+
return make_error<TableSamplerMixinError>(OtherRangeType, Location);
553+
return Error::success();
554+
}
555+
556+
/** This validation logic was extracted from the DXC codebase
557+
* https://github.com/microsoft/DirectXShaderCompiler/blob/7a1b1df9b50a8350a63756720e85196e0285e664/lib/DxilRootSignature/DxilRootSignatureValidator.cpp#L205
558+
*
559+
* It checks if the registers in a descriptor table are overflowing, meaning,
560+
* they are trying to bind a register larger than MAX_UINT.
561+
* This will usually happen when the descriptor table defined a range after an
562+
* unbounded range, which would lead to an overflow in the register;
563+
* Or if trying append a bunch or really large ranges.
564+
**/
565+
Error validateDescriptorTableRegisterOverflow(mcdxbc::DescriptorTable Table,
566+
uint32_t Location) {
567+
uint64_t AppendingRegister = 0;
568+
569+
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
570+
571+
dxbc::DescriptorRangeType RangeType =
572+
static_cast<dxbc::DescriptorRangeType>(Range.RangeType);
573+
574+
uint64_t Register = AppendingRegister;
575+
576+
// Checks if the current register should be appended to the previous range.
577+
if (Range.OffsetInDescriptorsFromTableStart != ~0U)
578+
Register = Range.OffsetInDescriptorsFromTableStart;
579+
580+
// Check for overflow in the register value.
581+
if (Register > ~0U)
582+
return make_error<TableRegisterOverflowError>(
583+
RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
584+
// Is the current range unbounded?
585+
if (Range.NumDescriptors == ~0U) {
586+
// No ranges should be appended to an unbounded range.
587+
AppendingRegister = (uint64_t)~0U + (uint64_t)1ULL;
588+
} else {
589+
// Is the defined range, overflowing?
590+
uint64_t UpperBound = (uint64_t)Range.BaseShaderRegister +
591+
(uint64_t)Range.NumDescriptors - (uint64_t)1U;
592+
if (UpperBound > ~0U)
593+
return make_error<TableRegisterOverflowError>(
594+
RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
595+
596+
// If we append this range, will it overflow?
597+
uint64_t AppendingUpperBound =
598+
(uint64_t)Register + (uint64_t)Range.NumDescriptors - (uint64_t)1U;
599+
if (AppendingUpperBound > ~0U)
600+
return make_error<TableRegisterOverflowError>(
601+
RangeType, Range.BaseShaderRegister, Range.RegisterSpace);
602+
AppendingRegister = Register + Range.NumDescriptors;
603+
}
604+
}
605+
606+
return Error::success();
607+
}
608+
528609
Error MetadataParser::validateRootSignature(
529610
const mcdxbc::RootSignatureDesc &RSD) {
530611
Error DeferredErrs = Error::success();
@@ -609,6 +690,16 @@ Error MetadataParser::validateRootSignature(
609690
joinErrors(std::move(DeferredErrs),
610691
make_error<RootSignatureValidationError<uint32_t>>(
611692
"DescriptorFlag", Range.Flags));
693+
694+
if (Error Err =
695+
validateDescriptorTableSamplerMixin(Table, Info.Location)) {
696+
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
697+
}
698+
699+
if (Error Err =
700+
validateDescriptorTableRegisterOverflow(Table, Info.Location)) {
701+
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
702+
}
612703
}
613704
break;
614705
}

0 commit comments

Comments
 (0)