From 82832bc1604b6677f25d96af028788c0f8648b15 Mon Sep 17 00:00:00 2001 From: Joao Saffran Date: Wed, 1 Oct 2025 12:23:38 -0700 Subject: [PATCH 1/2] fix validation logic --- clang/lib/Sema/SemaHLSL.cpp | 4 ++-- .../Frontend/HLSL/RootSignatureValidations.h | 2 +- .../Frontend/HLSL/RootSignatureMetadata.cpp | 3 +-- .../HLSL/RootSignatureValidations.cpp | 12 ++++++++++- ...escriptorTable-Invalid-Flag-LargeNumber.ll | 20 +++++++++++++++++++ ...ootDescriptor-Invalid-Flags-LargeNumber.ll | 18 +++++++++++++++++ 6 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll create mode 100644 llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 129b03c07c0bd..a2e8afb9bb8ff 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1322,8 +1322,8 @@ bool SemaHLSL::handleRootSignatureElements( ReportError(Loc, 1, 0xfffffffe); } - if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type, - Clause->Flags)) + if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( + Version, Clause->Type, llvm::to_underlying(Clause->Flags))) ReportFlagError(Loc); } } diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index 4dd18111b0c9d..10723a181f025 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -32,7 +32,7 @@ LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal); LLVM_ABI bool verifyRangeType(uint32_t Type); LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, - dxbc::DescriptorRangeFlags FlagsVal); + uint32_t FlagsVal); LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber); LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors); LLVM_ABI bool verifyMipLODBias(float MipLODBias); diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 5785505ce2b0c..2a22364563f90 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -665,8 +665,7 @@ Error MetadataParser::validateRootSignature( "NumDescriptors", Range.NumDescriptors)); if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, - dxbc::DescriptorRangeFlags(Range.Flags))) + RSD.Version, Range.RangeType, Range.Flags)) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error>( diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 2c78d622f7f28..e887906955dd2 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -36,6 +36,11 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) { bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { using FlagT = dxbc::RootDescriptorFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + if (FlagsVal >= NextPowerOf2(LargestValue)) + return false; + FlagT Flags = FlagT(FlagsVal); if (Version == 1) return Flags == FlagT::DataVolatile; @@ -54,9 +59,14 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { } bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, - dxbc::DescriptorRangeFlags Flags) { + uint32_t FlagsVal) { using FlagT = dxbc::DescriptorRangeFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + if (FlagsVal >= NextPowerOf2(LargestValue)) + return false; + FlagT Flags = FlagT(FlagsVal); const bool IsSampler = (Type == dxil::ResourceClass::Sampler); if (Version == 1) { diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll new file mode 100644 index 0000000000000..c27c87ff057d5 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-Flag-LargeNumber.ll @@ -0,0 +1,20 @@ +; RUN: not opt -passes='print' %s -S -o - 2>&1 | FileCheck %s + +target triple = "dxil-unknown-shadermodel6.0-compute" + +; CHECK: error: Invalid value for DescriptorFlag: 66666 +; CHECK-NOT: Root Signature Definitions + +define void @main() #0 { +entry: + ret void +} +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + + +!dx.rootsignatures = !{!2} ; list of function/root signature pairs +!2 = !{ ptr @main, !3, i32 2 } ; function, root signature +!3 = !{ !5 } ; list of root signature elements +!5 = !{ !"DescriptorTable", i32 0, !6, !7 } +!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 66666 } +!7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 } diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll new file mode 100644 index 0000000000000..898e197c7e0cc --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-Flags-LargeNumber.ll @@ -0,0 +1,18 @@ +; RUN: not opt -passes='print' %s -S -o - 2>&1 | FileCheck %s + +target triple = "dxil-unknown-shadermodel6.0-compute" + + +; CHECK: error: Invalid value for RootDescriptorFlag: 666 +; CHECK-NOT: Root Signature Definitions +define void @main() #0 { +entry: + ret void +} +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + + +!dx.rootsignatures = !{!2} ; list of function/root signature pairs +!2 = !{ ptr @main, !3, i32 2 } ; function, root signature +!3 = !{ !5 } ; list of root signature elements +!5 = !{ !"RootCBV", i32 0, i32 1, i32 2, i32 666 } From 23bea276768244273ac50fbeb366916a81569ab4 Mon Sep 17 00:00:00 2001 From: Joao Saffran Date: Mon, 6 Oct 2025 13:29:30 -0700 Subject: [PATCH 2/2] addressing comment from bogner --- clang/lib/Sema/SemaHLSL.cpp | 8 +++--- llvm/include/llvm/BinaryFormat/DXContainer.h | 6 +++++ .../Frontend/HLSL/RootSignatureValidations.h | 8 +++--- llvm/lib/BinaryFormat/DXContainer.cpp | 21 ++++++++++++++++ .../Frontend/HLSL/RootSignatureMetadata.cpp | 21 +++++++++++----- .../HLSL/RootSignatureValidations.cpp | 25 ++++--------------- 6 files changed, 56 insertions(+), 33 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index b392ec648598f..a662b72c2a362 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1289,8 +1289,8 @@ bool SemaHLSL::handleRootSignatureElements( VerifyRegister(Loc, Descriptor->Reg.Number); VerifySpace(Loc, Descriptor->Space); - if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag( - Version, llvm::to_underlying(Descriptor->Flags))) + if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version, + Descriptor->Flags)) ReportFlagError(Loc); } else if (const auto *Constants = std::get_if(&Elem)) { @@ -1322,8 +1322,8 @@ bool SemaHLSL::handleRootSignatureElements( ReportError(Loc, 1, 0xfffffffe); } - if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( - Version, Clause->Type, llvm::to_underlying(Clause->Flags))) + if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(Version, Clause->Type, + Clause->Flags)) ReportFlagError(Loc); } } diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h index 0b5646229e8b5..b9a08ce1ca14e 100644 --- a/llvm/include/llvm/BinaryFormat/DXContainer.h +++ b/llvm/include/llvm/BinaryFormat/DXContainer.h @@ -248,6 +248,12 @@ enum class StaticBorderColor : uint32_t { bool isValidBorderColor(uint32_t V); +bool isValidRootDesciptorFlags(uint32_t V); + +bool isValidDescriptorRangeFlags(uint32_t V); + +bool isValidStaticSamplerFlags(uint32_t V); + LLVM_ABI ArrayRef> getStaticBorderColors(); LLVM_ABI PartType parsePartType(StringRef S); diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index 10723a181f025..7131980e9ff3a 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -28,12 +28,14 @@ LLVM_ABI bool verifyRootFlag(uint32_t Flags); LLVM_ABI bool verifyVersion(uint32_t Version); LLVM_ABI bool verifyRegisterValue(uint32_t RegisterValue); LLVM_ABI bool verifyRegisterSpace(uint32_t RegisterSpace); -LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal); +LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, + dxbc::RootDescriptorFlags Flags); LLVM_ABI bool verifyRangeType(uint32_t Type); LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, - uint32_t FlagsVal); -LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber); + dxbc::DescriptorRangeFlags Flags); +LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, + dxbc::StaticSamplerFlags Flags); LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors); LLVM_ABI bool verifyMipLODBias(float MipLODBias); LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy); diff --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp index b334f86568acb..22f518067b318 100644 --- a/llvm/lib/BinaryFormat/DXContainer.cpp +++ b/llvm/lib/BinaryFormat/DXContainer.cpp @@ -82,6 +82,27 @@ bool llvm::dxbc::isValidBorderColor(uint32_t V) { return false; } +bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) { + using FlagT = dxbc::RootDescriptorFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + +bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) { + using FlagT = dxbc::DescriptorRangeFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + +bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) { + using FlagT = dxbc::StaticSamplerFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + dxbc::PartType dxbc::parsePartType(StringRef S) { #define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName) return StringSwitch(S) diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index b36d1234a6774..707f0c368e9d8 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature( "RegisterSpace", Descriptor.RegisterSpace)); if (RSD.Version > 1) { - if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) + bool IsValidFlag = + dxbc::isValidRootDesciptorFlags(Descriptor.Flags) && + hlsl::rootsig::verifyRootDescriptorFlag( + RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error>( @@ -676,8 +679,11 @@ Error MetadataParser::validateRootSignature( make_error>( "NumDescriptors", Range.NumDescriptors)); - if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, Range.Flags)) + bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) && + hlsl::rootsig::verifyDescriptorRangeFlag( + RSD.Version, Range.RangeType, + dxbc::DescriptorRangeFlags(Range.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error>( @@ -730,8 +736,11 @@ Error MetadataParser::validateRootSignature( joinErrors(std::move(DeferredErrs), make_error>( "RegisterSpace", Sampler.RegisterSpace)); - - if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags)) + bool IsValidFlag = + dxbc::isValidStaticSamplerFlags(Sampler.Flags) && + hlsl::rootsig::verifyStaticSamplerFlags( + RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error>( diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 47a73060924b0..30408dfda940d 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -34,13 +34,9 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) { return !(RegisterSpace >= 0xFFFFFFF0); } -bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { +bool verifyRootDescriptorFlag(uint32_t Version, + dxbc::RootDescriptorFlags FlagsVal) { using FlagT = dxbc::RootDescriptorFlags; - uint32_t LargestValue = - llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); - if (FlagsVal >= NextPowerOf2(LargestValue)) - return false; - FlagT Flags = FlagT(FlagsVal); if (Version == 1) return Flags == FlagT::DataVolatile; @@ -59,14 +55,8 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { } bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, - uint32_t FlagsVal) { + dxbc::DescriptorRangeFlags Flags) { using FlagT = dxbc::DescriptorRangeFlags; - uint32_t LargestValue = - llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); - if (FlagsVal >= NextPowerOf2(LargestValue)) - return false; - - FlagT Flags = FlagT(FlagsVal); const bool IsSampler = (Type == dxil::ResourceClass::Sampler); if (Version == 1) { @@ -123,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, return (Flags & ~Mask) == FlagT::None; } -bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) { - uint32_t LargestValue = llvm::to_underlying( - dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR); - if (FlagsNumber >= NextPowerOf2(LargestValue)) - return false; - - dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber); +bool verifyStaticSamplerFlags(uint32_t Version, + dxbc::StaticSamplerFlags Flags) { if (Version <= 2) return Flags == dxbc::StaticSamplerFlags::None;