diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index ee4e3cc90118d..9b68a524432cc 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -21,6 +21,24 @@ namespace llvm { namespace hlsl { namespace rootsig { +// Basic verification of RootElements + +bool verifyRootFlag(uint32_t Flags); +bool verifyVersion(uint32_t Version); +bool verifyRegisterValue(uint32_t RegisterValue); +bool verifyRegisterSpace(uint32_t RegisterSpace); +bool verifyDescriptorFlag(uint32_t Flags); +bool verifyRangeType(uint32_t Type); +bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type, + uint32_t FlagsVal); +bool verifySamplerFilter(uint32_t Value); +bool verifyAddress(uint32_t Address); +bool verifyMipLODBias(float MipLODBias); +bool verifyMaxAnisotropy(uint32_t MaxAnisotropy); +bool verifyComparisonFunc(uint32_t ComparisonFunc); +bool verifyBorderColor(uint32_t BorderColor); +bool verifyLOD(float LOD); + struct RangeInfo { const static uint32_t Unbounded = ~0u; diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 9825946d59690..b5b5fc0c74d83 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -12,10 +12,153 @@ #include "llvm/Frontend/HLSL/RootSignatureValidations.h" +#include + namespace llvm { namespace hlsl { namespace rootsig { +bool verifyRootFlag(uint32_t Flags) { return (Flags & ~0xfff) == 0; } + +bool verifyVersion(uint32_t Version) { return (Version == 1 || Version == 2); } + +bool verifyRegisterValue(uint32_t RegisterValue) { + return RegisterValue != ~0U; +} + +// This Range is reserverved, therefore invalid, according to the spec +// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal +bool verifyRegisterSpace(uint32_t RegisterSpace) { + return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF); +} + +bool verifyDescriptorFlag(uint32_t Flags) { return (Flags & ~0xE) == 0; } + +bool verifyRangeType(uint32_t Type) { + switch (Type) { + case llvm::to_underlying(dxbc::DescriptorRangeType::CBV): + case llvm::to_underlying(dxbc::DescriptorRangeType::SRV): + case llvm::to_underlying(dxbc::DescriptorRangeType::UAV): + case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler): + return true; + }; + + return false; +} + +bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type, + uint32_t FlagsVal) { + using FlagT = dxbc::DescriptorRangeFlags; + FlagT Flags = FlagT(FlagsVal); + + const bool IsSampler = + (Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)); + + if (Version == 1) { + // Since the metadata is unversioned, we expect to explicitly see the values + // that map to the version 1 behaviour here. + if (IsSampler) + return Flags == FlagT::DescriptorsVolatile; + return Flags == (FlagT::DataVolatile | FlagT::DescriptorsVolatile); + } + + // The data-specific flags are mutually exclusive. + FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic | + FlagT::DataStaticWhileSetAtExecute; + + if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1) + return false; + + // The descriptor-specific flags are mutually exclusive. + FlagT DescriptorFlags = FlagT::DescriptorsStaticKeepingBufferBoundsChecks | + FlagT::DescriptorsVolatile; + if (popcount(llvm::to_underlying(Flags & DescriptorFlags)) > 1) + return false; + + // For volatile descriptors, DATA_is never valid. + if ((Flags & FlagT::DescriptorsVolatile) == FlagT::DescriptorsVolatile) { + FlagT Mask = FlagT::DescriptorsVolatile; + if (!IsSampler) { + Mask |= FlagT::DataVolatile; + Mask |= FlagT::DataStaticWhileSetAtExecute; + } + return (Flags & ~Mask) == FlagT::None; + } + + // For "KEEPING_BUFFER_BOUNDS_CHECKS" descriptors, + // the other data-specific flags may all be set. + if ((Flags & FlagT::DescriptorsStaticKeepingBufferBoundsChecks) == + FlagT::DescriptorsStaticKeepingBufferBoundsChecks) { + FlagT Mask = FlagT::DescriptorsStaticKeepingBufferBoundsChecks; + if (!IsSampler) { + Mask |= FlagT::DataVolatile; + Mask |= FlagT::DataStatic; + Mask |= FlagT::DataStaticWhileSetAtExecute; + } + return (Flags & ~Mask) == FlagT::None; + } + + // When no descriptor flag is set, any data flag is allowed. + FlagT Mask = FlagT::None; + if (!IsSampler) { + Mask |= FlagT::DataVolatile; + Mask |= FlagT::DataStaticWhileSetAtExecute; + Mask |= FlagT::DataStatic; + } + return (Flags & ~Mask) == FlagT::None; +} + +bool verifySamplerFilter(uint32_t Value) { + switch (Value) { +#define FILTER(Num, Val) case llvm::to_underlying(dxbc::SamplerFilter::Val): +#include "llvm/BinaryFormat/DXContainerConstants.def" + return true; + } + return false; +} + +// Values allowed here: +// https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax +bool verifyAddress(uint32_t Address) { + switch (Address) { +#define TEXTURE_ADDRESS_MODE(Num, Val) \ + case llvm::to_underlying(dxbc::TextureAddressMode::Val): +#include "llvm/BinaryFormat/DXContainerConstants.def" + return true; + } + return false; +} + +bool verifyMipLODBias(float MipLODBias) { + return MipLODBias >= -16.f && MipLODBias <= 15.99f; +} + +bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) { + return MaxAnisotropy <= 16u; +} + +bool verifyComparisonFunc(uint32_t ComparisonFunc) { + switch (ComparisonFunc) { +#define COMPARISON_FUNC(Num, Val) \ + case llvm::to_underlying(dxbc::ComparisonFunc::Val): +#include "llvm/BinaryFormat/DXContainerConstants.def" + return true; + } + return false; +} + +bool verifyBorderColor(uint32_t BorderColor) { + switch (BorderColor) { +#define STATIC_BORDER_COLOR(Num, Val) \ + case llvm::to_underlying(dxbc::StaticBorderColor::Val): +#include "llvm/BinaryFormat/DXContainerConstants.def" + return true; + } + return false; +} + +bool verifyLOD(float LOD) { return !std::isnan(LOD); } + std::optional ResourceRange::getOverlapping(const RangeInfo &Info) const { MapT::const_iterator Interval = Intervals.find(Info.LowerBound); diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index 5e0975736f90d..e46b184a353f1 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/BinaryFormat/DXContainer.h" +#include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" @@ -27,7 +28,6 @@ #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include #include #include #include @@ -399,156 +399,13 @@ static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, return HasError; } -static bool verifyRootFlag(uint32_t Flags) { return (Flags & ~0xfff) == 0; } - -static bool verifyVersion(uint32_t Version) { - return (Version == 1 || Version == 2); -} - -static bool verifyRegisterValue(uint32_t RegisterValue) { - return RegisterValue != ~0U; -} - -// This Range is reserverved, therefore invalid, according to the spec -// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal -static bool verifyRegisterSpace(uint32_t RegisterSpace) { - return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF); -} - -static bool verifyDescriptorFlag(uint32_t Flags) { return (Flags & ~0xE) == 0; } - -static bool verifyRangeType(uint32_t Type) { - switch (Type) { - case llvm::to_underlying(dxbc::DescriptorRangeType::CBV): - case llvm::to_underlying(dxbc::DescriptorRangeType::SRV): - case llvm::to_underlying(dxbc::DescriptorRangeType::UAV): - case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler): - return true; - }; - - return false; -} - -static bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type, - uint32_t FlagsVal) { - using FlagT = dxbc::DescriptorRangeFlags; - FlagT Flags = FlagT(FlagsVal); - - const bool IsSampler = - (Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)); - - if (Version == 1) { - // Since the metadata is unversioned, we expect to explicitly see the values - // that map to the version 1 behaviour here. - if (IsSampler) - return Flags == FlagT::DescriptorsVolatile; - return Flags == (FlagT::DataVolatile | FlagT::DescriptorsVolatile); - } - - // The data-specific flags are mutually exclusive. - FlagT DataFlags = FlagT::DataVolatile | FlagT::DataStatic | - FlagT::DataStaticWhileSetAtExecute; - - if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1) - return false; - - // The descriptor-specific flags are mutually exclusive. - FlagT DescriptorFlags = FlagT::DescriptorsStaticKeepingBufferBoundsChecks | - FlagT::DescriptorsVolatile; - if (popcount(llvm::to_underlying(Flags & DescriptorFlags)) > 1) - return false; - - // For volatile descriptors, DATA_STATIC is never valid. - if ((Flags & FlagT::DescriptorsVolatile) == FlagT::DescriptorsVolatile) { - FlagT Mask = FlagT::DescriptorsVolatile; - if (!IsSampler) { - Mask |= FlagT::DataVolatile; - Mask |= FlagT::DataStaticWhileSetAtExecute; - } - return (Flags & ~Mask) == FlagT::None; - } - - // For "STATIC_KEEPING_BUFFER_BOUNDS_CHECKS" descriptors, - // the other data-specific flags may all be set. - if ((Flags & FlagT::DescriptorsStaticKeepingBufferBoundsChecks) == - FlagT::DescriptorsStaticKeepingBufferBoundsChecks) { - FlagT Mask = FlagT::DescriptorsStaticKeepingBufferBoundsChecks; - if (!IsSampler) { - Mask |= FlagT::DataVolatile; - Mask |= FlagT::DataStatic; - Mask |= FlagT::DataStaticWhileSetAtExecute; - } - return (Flags & ~Mask) == FlagT::None; - } - - // When no descriptor flag is set, any data flag is allowed. - FlagT Mask = FlagT::None; - if (!IsSampler) { - Mask |= FlagT::DataVolatile; - Mask |= FlagT::DataStaticWhileSetAtExecute; - Mask |= FlagT::DataStatic; - } - return (Flags & ~Mask) == FlagT::None; -} - -static bool verifySamplerFilter(uint32_t Value) { - switch (Value) { -#define FILTER(Num, Val) case llvm::to_underlying(dxbc::SamplerFilter::Val): -#include "llvm/BinaryFormat/DXContainerConstants.def" - return true; - } - return false; -} - -// Values allowed here: -// https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax -static bool verifyAddress(uint32_t Address) { - switch (Address) { -#define TEXTURE_ADDRESS_MODE(Num, Val) \ - case llvm::to_underlying(dxbc::TextureAddressMode::Val): -#include "llvm/BinaryFormat/DXContainerConstants.def" - return true; - } - return false; -} - -static bool verifyMipLODBias(float MipLODBias) { - return MipLODBias >= -16.f && MipLODBias <= 15.99f; -} - -static bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) { - return MaxAnisotropy <= 16u; -} - -static bool verifyComparisonFunc(uint32_t ComparisonFunc) { - switch (ComparisonFunc) { -#define COMPARISON_FUNC(Num, Val) \ - case llvm::to_underlying(dxbc::ComparisonFunc::Val): -#include "llvm/BinaryFormat/DXContainerConstants.def" - return true; - } - return false; -} - -static bool verifyBorderColor(uint32_t BorderColor) { - switch (BorderColor) { -#define STATIC_BORDER_COLOR(Num, Val) \ - case llvm::to_underlying(dxbc::StaticBorderColor::Val): -#include "llvm/BinaryFormat/DXContainerConstants.def" - return true; - } - return false; -} - -static bool verifyLOD(float LOD) { return !std::isnan(LOD); } - static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { - if (!verifyVersion(RSD.Version)) { + if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) { return reportValueError(Ctx, "Version", RSD.Version); } - if (!verifyRootFlag(RSD.Flags)) { + if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) { return reportValueError(Ctx, "RootFlags", RSD.Flags); } @@ -567,15 +424,15 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { case llvm::to_underlying(dxbc::RootParameterType::SRV): { const dxbc::RTS0::v2::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location); - if (!verifyRegisterValue(Descriptor.ShaderRegister)) + if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) return reportValueError(Ctx, "ShaderRegister", Descriptor.ShaderRegister); - if (!verifyRegisterSpace(Descriptor.RegisterSpace)) + if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace); if (RSD.Version > 1) { - if (!verifyDescriptorFlag(Descriptor.Flags)) + if (!llvm::hlsl::rootsig::verifyDescriptorFlag(Descriptor.Flags)) return reportValueError(Ctx, "DescriptorRangeFlag", Descriptor.Flags); } break; @@ -584,14 +441,14 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { const mcdxbc::DescriptorTable &Table = RSD.ParametersContainer.getDescriptorTable(Info.Location); for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { - if (!verifyRangeType(Range.RangeType)) + if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType)) return reportValueError(Ctx, "RangeType", Range.RangeType); - if (!verifyRegisterSpace(Range.RegisterSpace)) + if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace); - if (!verifyDescriptorRangeFlag(RSD.Version, Range.RangeType, - Range.Flags)) + if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( + RSD.Version, Range.RangeType, Range.Flags)) return reportValueError(Ctx, "DescriptorFlag", Range.Flags); } break; @@ -600,40 +457,40 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { } for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { - if (!verifySamplerFilter(Sampler.Filter)) + if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) return reportValueError(Ctx, "Filter", Sampler.Filter); - if (!verifyAddress(Sampler.AddressU)) + if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU)) return reportValueError(Ctx, "AddressU", Sampler.AddressU); - if (!verifyAddress(Sampler.AddressV)) + if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV)) return reportValueError(Ctx, "AddressV", Sampler.AddressV); - if (!verifyAddress(Sampler.AddressW)) + if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW)) return reportValueError(Ctx, "AddressW", Sampler.AddressW); - if (!verifyMipLODBias(Sampler.MipLODBias)) + if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias); - if (!verifyMaxAnisotropy(Sampler.MaxAnisotropy)) + if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy); - if (!verifyComparisonFunc(Sampler.ComparisonFunc)) + if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc); - if (!verifyBorderColor(Sampler.BorderColor)) + if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) return reportValueError(Ctx, "BorderColor", Sampler.BorderColor); - if (!verifyLOD(Sampler.MinLOD)) + if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD)) return reportValueError(Ctx, "MinLOD", Sampler.MinLOD); - if (!verifyLOD(Sampler.MaxLOD)) + if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD); - if (!verifyRegisterValue(Sampler.ShaderRegister)) + if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister); - if (!verifyRegisterSpace(Sampler.RegisterSpace)) + if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace); if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))