Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1289,8 +1289,8 @@ bool SemaHLSL::handleRootSignatureElements(
VerifyRegister(Loc, Descriptor->Reg.Number);
VerifySpace(Loc, Descriptor->Space);

if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(
Version, llvm::to_underlying(Descriptor->Flags)))
if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
Descriptor->Flags))
ReportFlagError(Loc);
} else if (const auto *Constants =
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/BinaryFormat/DXContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ enum class StaticBorderColor : uint32_t {

bool isValidBorderColor(uint32_t V);

bool isValidRootDesciptorFlags(uint32_t V);

bool isValidDescriptorRangeFlags(uint32_t V);

bool isValidStaticSamplerFlags(uint32_t V);

LLVM_ABI ArrayRef<EnumEntry<StaticBorderColor>> getStaticBorderColors();

LLVM_ABI PartType parsePartType(StringRef S);
Expand Down
8 changes: 5 additions & 3 deletions llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ LLVM_ABI bool verifyRootFlag(uint32_t Flags);
LLVM_ABI bool verifyVersion(uint32_t Version);
LLVM_ABI bool verifyRegisterValue(uint32_t RegisterValue);
LLVM_ABI bool verifyRegisterSpace(uint32_t RegisterSpace);
LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal);
LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version,
dxbc::RootDescriptorFlags Flags);
LLVM_ABI bool verifyRangeType(uint32_t Type);
LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version,
dxil::ResourceClass Type,
dxbc::DescriptorRangeFlags FlagsVal);
LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber);
dxbc::DescriptorRangeFlags Flags);
LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version,
dxbc::StaticSamplerFlags Flags);
LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors);
LLVM_ABI bool verifyMipLODBias(float MipLODBias);
LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy);
Expand Down
21 changes: 21 additions & 0 deletions llvm/lib/BinaryFormat/DXContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,27 @@ bool llvm::dxbc::isValidBorderColor(uint32_t V) {
return false;
}

bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) {
using FlagT = dxbc::RootDescriptorFlags;
uint32_t LargestValue =
llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
return V < NextPowerOf2(LargestValue);
}

bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) {
using FlagT = dxbc::DescriptorRangeFlags;
uint32_t LargestValue =
llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
return V < NextPowerOf2(LargestValue);
}

bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) {
using FlagT = dxbc::StaticSamplerFlags;
uint32_t LargestValue =
llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
return V < NextPowerOf2(LargestValue);
}

dxbc::PartType dxbc::parsePartType(StringRef S) {
#define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName)
return StringSwitch<dxbc::PartType>(S)
Expand Down
22 changes: 15 additions & 7 deletions llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature(
"RegisterSpace", Descriptor.RegisterSpace));

if (RSD.Version > 1) {
if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
Descriptor.Flags))
bool IsValidFlag =
dxbc::isValidRootDesciptorFlags(Descriptor.Flags) &&
hlsl::rootsig::verifyRootDescriptorFlag(
RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags));
if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
Expand All @@ -676,9 +679,11 @@ Error MetadataParser::validateRootSignature(
make_error<RootSignatureValidationError<uint32_t>>(
"NumDescriptors", Range.NumDescriptors));

if (!hlsl::rootsig::verifyDescriptorRangeFlag(
RSD.Version, Range.RangeType,
dxbc::DescriptorRangeFlags(Range.Flags)))
bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) &&
hlsl::rootsig::verifyDescriptorRangeFlag(
RSD.Version, Range.RangeType,
dxbc::DescriptorRangeFlags(Range.Flags));
if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
Expand Down Expand Up @@ -731,8 +736,11 @@ Error MetadataParser::validateRootSignature(
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"RegisterSpace", Sampler.RegisterSpace));

if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags))
bool IsValidFlag =
dxbc::isValidStaticSamplerFlags(Sampler.Flags) &&
hlsl::rootsig::verifyStaticSamplerFlags(
RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags));
if (!IsValidFlag)
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
Expand Down
13 changes: 4 additions & 9 deletions llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) {
return !(RegisterSpace >= 0xFFFFFFF0);
}

bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
bool verifyRootDescriptorFlag(uint32_t Version,
dxbc::RootDescriptorFlags FlagsVal) {
using FlagT = dxbc::RootDescriptorFlags;
FlagT Flags = FlagT(FlagsVal);
if (Version == 1)
Expand All @@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
dxbc::DescriptorRangeFlags Flags) {
using FlagT = dxbc::DescriptorRangeFlags;

const bool IsSampler = (Type == dxil::ResourceClass::Sampler);

if (Version == 1) {
Expand Down Expand Up @@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
return (Flags & ~Mask) == FlagT::None;
}

bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) {
uint32_t LargestValue = llvm::to_underlying(
dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR);
if (FlagsNumber >= NextPowerOf2(LargestValue))
return false;

dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber);
bool verifyStaticSamplerFlags(uint32_t Version,
dxbc::StaticSamplerFlags Flags) {
if (Version <= 2)
return Flags == dxbc::StaticSamplerFlags::None;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s

target triple = "dxil-unknown-shadermodel6.0-compute"

; CHECK: error: Invalid value for DescriptorFlag: 66666
; CHECK-NOT: Root Signature Definitions

define void @main() #0 {
entry:
ret void
}
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }


!dx.rootsignatures = !{!2} ; list of function/root signature pairs
!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
!3 = !{ !5 } ; list of root signature elements
!5 = !{ !"DescriptorTable", i32 0, !6, !7 }
!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 66666 }
!7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s

target triple = "dxil-unknown-shadermodel6.0-compute"


; CHECK: error: Invalid value for RootDescriptorFlag: 666
; CHECK-NOT: Root Signature Definitions
define void @main() #0 {
entry:
ret void
}
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }


!dx.rootsignatures = !{!2} ; list of function/root signature pairs
!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
!3 = !{ !5 } ; list of root signature elements
!5 = !{ !"RootCBV", i32 0, i32 1, i32 2, i32 666 }