diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h index facd137e9d9dd..c04380667a640 100644 --- a/llvm/include/llvm/BinaryFormat/DXContainer.h +++ b/llvm/include/llvm/BinaryFormat/DXContainer.h @@ -228,6 +228,16 @@ enum class SamplerFilter : uint32_t { #include "DXContainerConstants.def" }; +#define FILTER(Val, Enum) \ + case Val: \ + return true; +inline bool isValidSamplerFilter(uint32_t V) { + switch (V) { +#include "DXContainerConstants.def" + } + return false; +} + LLVM_ABI ArrayRef> getSamplerFilters(); #define TEXTURE_ADDRESS_MODE(Val, Enum) Enum = Val, @@ -237,6 +247,16 @@ enum class TextureAddressMode : uint32_t { LLVM_ABI ArrayRef> getTextureAddressModes(); +#define TEXTURE_ADDRESS_MODE(Val, Enum) \ + case Val: \ + return true; +inline bool isValidAddress(uint32_t V) { + switch (V) { +#include "DXContainerConstants.def" + } + return false; +} + #define COMPARISON_FUNC(Val, Enum) Enum = Val, enum class ComparisonFunc : uint32_t { #include "DXContainerConstants.def" @@ -244,11 +264,31 @@ enum class ComparisonFunc : uint32_t { LLVM_ABI ArrayRef> getComparisonFuncs(); +#define COMPARISON_FUNC(Val, Enum) \ + case Val: \ + return true; +inline bool isValidComparisonFunc(uint32_t V) { + switch (V) { +#include "DXContainerConstants.def" + } + return false; +} + #define STATIC_BORDER_COLOR(Val, Enum) Enum = Val, enum class StaticBorderColor : uint32_t { #include "DXContainerConstants.def" }; +#define STATIC_BORDER_COLOR(Val, Enum) \ + case Val: \ + return true; +inline bool isValidBorderColor(uint32_t V) { + switch (V) { +#include "DXContainerConstants.def" + } + return false; +} + LLVM_ABI ArrayRef> getStaticBorderColors(); LLVM_ABI PartType parsePartType(StringRef S); diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index 24e851933949f..ea96094b18300 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -34,12 +34,8 @@ LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, dxbc::DescriptorRangeFlags FlagsVal); LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors); -LLVM_ABI bool verifySamplerFilter(uint32_t Value); -LLVM_ABI bool verifyAddress(uint32_t Address); LLVM_ABI bool verifyMipLODBias(float MipLODBias); LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy); -LLVM_ABI bool verifyComparisonFunc(uint32_t ComparisonFunc); -LLVM_ABI bool verifyBorderColor(uint32_t BorderColor); LLVM_ABI bool verifyLOD(float LOD); LLVM_ABI bool verifyBoundOffset(uint32_t Offset); diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h index f2722fd37a4f1..54677ef70244f 100644 --- a/llvm/include/llvm/MC/DXContainerRootSignature.h +++ b/llvm/include/llvm/MC/DXContainerRootSignature.h @@ -60,6 +60,22 @@ struct DescriptorTable { } }; +struct StaticSampler { + dxbc::SamplerFilter Filter; + dxbc::TextureAddressMode AddressU; + dxbc::TextureAddressMode AddressV; + dxbc::TextureAddressMode AddressW; + float MipLODBias; + uint32_t MaxAnisotropy; + dxbc::ComparisonFunc ComparisonFunc; + dxbc::StaticBorderColor BorderColor; + float MinLOD; + float MaxLOD; + uint32_t ShaderRegister; + uint32_t RegisterSpace; + dxbc::ShaderVisibility ShaderVisibility; +}; + struct RootParametersContainer { SmallVector ParametersInfo; @@ -125,7 +141,7 @@ struct RootSignatureDesc { uint32_t StaticSamplersOffset = 0u; uint32_t NumStaticSamplers = 0u; mcdxbc::RootParametersContainer ParametersContainer; - SmallVector StaticSamplers; + SmallVector StaticSamplers; LLVM_ABI void write(raw_ostream &OS) const; diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 31605e3900341..f29f2c7602fc6 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -52,13 +52,15 @@ static std::optional extractMdStringValue(MDNode *Node, return NodeText->getString(); } -static Expected -extractShaderVisibility(MDNode *Node, unsigned int OpId) { +template && + std::is_same_v, uint32_t>>> +Expected extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText, + llvm::function_ref VerifyFn) { if (std::optional Val = extractMdIntValue(Node, OpId)) { - if (!dxbc::isValidShaderVisibility(*Val)) - return make_error>( - "ShaderVisibility", *Val); - return dxbc::ShaderVisibility(*Val); + if (!VerifyFn(*Val)) + return make_error>(ErrText, *Val); + return static_cast(*Val); } return make_error("ShaderVisibility"); } @@ -233,7 +235,9 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD, return make_error("RootConstants Element"); Expected Visibility = - extractShaderVisibility(RootConstantNode, 1); + extractEnumValue(RootConstantNode, 1, + "ShaderVisibility", + dxbc::isValidShaderVisibility); if (auto E = Visibility.takeError()) return Error(std::move(E)); @@ -287,7 +291,9 @@ Error MetadataParser::parseRootDescriptors( } Expected Visibility = - extractShaderVisibility(RootDescriptorNode, 1); + extractEnumValue(RootDescriptorNode, 1, + "ShaderVisibility", + dxbc::isValidShaderVisibility); if (auto E = Visibility.takeError()) return Error(std::move(E)); @@ -380,7 +386,9 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD, return make_error("Descriptor Table"); Expected Visibility = - extractShaderVisibility(DescriptorTableNode, 1); + extractEnumValue(DescriptorTableNode, 1, + "ShaderVisibility", + dxbc::isValidShaderVisibility); if (auto E = Visibility.takeError()) return Error(std::move(E)); @@ -406,26 +414,34 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, if (StaticSamplerNode->getNumOperands() != 14) return make_error("Static Sampler"); - dxbc::RTS0::v1::StaticSampler Sampler; - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 1)) - Sampler.Filter = *Val; - else - return make_error("Filter"); + mcdxbc::StaticSampler Sampler; - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 2)) - Sampler.AddressU = *Val; - else - return make_error("AddressU"); + Expected Filter = extractEnumValue( + StaticSamplerNode, 1, "Filter", dxbc::isValidSamplerFilter); + if (auto E = Filter.takeError()) + return Error(std::move(E)); + Sampler.Filter = *Filter; - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 3)) - Sampler.AddressV = *Val; - else - return make_error("AddressV"); + Expected AddressU = + extractEnumValue( + StaticSamplerNode, 2, "AddressU", dxbc::isValidAddress); + if (auto E = AddressU.takeError()) + return Error(std::move(E)); + Sampler.AddressU = *AddressU; - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 4)) - Sampler.AddressW = *Val; - else - return make_error("AddressW"); + Expected AddressV = + extractEnumValue( + StaticSamplerNode, 3, "AddressV", dxbc::isValidAddress); + if (auto E = AddressV.takeError()) + return Error(std::move(E)); + Sampler.AddressV = *AddressV; + + Expected AddressW = + extractEnumValue( + StaticSamplerNode, 4, "AddressW", dxbc::isValidAddress); + if (auto E = AddressW.takeError()) + return Error(std::move(E)); + Sampler.AddressW = *AddressW; if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 5)) Sampler.MipLODBias = *Val; @@ -437,15 +453,19 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, else return make_error("MaxAnisotropy"); - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 7)) - Sampler.ComparisonFunc = *Val; - else - return make_error("ComparisonFunc"); + Expected ComparisonFunc = + extractEnumValue( + StaticSamplerNode, 7, "ComparisonFunc", dxbc::isValidComparisonFunc); + if (auto E = ComparisonFunc.takeError()) + return Error(std::move(E)); + Sampler.ComparisonFunc = *ComparisonFunc; - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 8)) - Sampler.BorderColor = *Val; - else - return make_error("ComparisonFunc"); + Expected BorderColor = + extractEnumValue( + StaticSamplerNode, 8, "BorderColor", dxbc::isValidBorderColor); + if (auto E = BorderColor.takeError()) + return Error(std::move(E)); + Sampler.BorderColor = *BorderColor; if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 9)) Sampler.MinLOD = *Val; @@ -467,10 +487,13 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD, else return make_error("RegisterSpace"); - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 13)) - Sampler.ShaderVisibility = *Val; - else - return make_error("ShaderVisibility"); + Expected Visibility = + extractEnumValue(StaticSamplerNode, 13, + "ShaderVisibility", + dxbc::isValidShaderVisibility); + if (auto E = Visibility.takeError()) + return Error(std::move(E)); + Sampler.ShaderVisibility = *Visibility; RSD.StaticSamplers.push_back(Sampler); return Error::success(); @@ -594,30 +617,7 @@ Error MetadataParser::validateRootSignature( } } - for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { - if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error>( - "Filter", Sampler.Filter)); - - if (!hlsl::rootsig::verifyAddress(Sampler.AddressU)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error>( - "AddressU", Sampler.AddressU)); - - if (!hlsl::rootsig::verifyAddress(Sampler.AddressV)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error>( - "AddressV", Sampler.AddressV)); - - if (!hlsl::rootsig::verifyAddress(Sampler.AddressW)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error>( - "AddressW", Sampler.AddressW)); + for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers) { if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) DeferredErrs = joinErrors(std::move(DeferredErrs), @@ -630,18 +630,6 @@ Error MetadataParser::validateRootSignature( make_error>( "MaxAnisotropy", Sampler.MaxAnisotropy)); - if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error>( - "ComparisonFunc", Sampler.ComparisonFunc)); - - if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error>( - "BorderColor", Sampler.BorderColor)); - if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD)) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error>( @@ -663,12 +651,6 @@ Error MetadataParser::validateRootSignature( joinErrors(std::move(DeferredErrs), make_error>( "RegisterSpace", Sampler.RegisterSpace)); - - if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) - DeferredErrs = - joinErrors(std::move(DeferredErrs), - make_error>( - "ShaderVisibility", Sampler.ShaderVisibility)); } return DeferredErrs; diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index d682dda0bab26..0970977b5064f 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -115,27 +115,6 @@ bool verifyNumDescriptors(uint32_t NumDescriptors) { return NumDescriptors > 0; } -bool verifySamplerFilter(uint32_t Value) { - switch (Value) { -#define FILTER(Num, Val) case llvm::to_underlying(dxbc::SamplerFilter::Val): -#include "llvm/BinaryFormat/DXContainerConstants.def" - return true; - } - return false; -} - -// Values allowed here: -// https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax -bool verifyAddress(uint32_t Address) { - switch (Address) { -#define TEXTURE_ADDRESS_MODE(Num, Val) \ - case llvm::to_underlying(dxbc::TextureAddressMode::Val): -#include "llvm/BinaryFormat/DXContainerConstants.def" - return true; - } - return false; -} - bool verifyMipLODBias(float MipLODBias) { return MipLODBias >= -16.f && MipLODBias <= 15.99f; } @@ -144,26 +123,6 @@ bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) { return MaxAnisotropy <= 16u; } -bool verifyComparisonFunc(uint32_t ComparisonFunc) { - switch (ComparisonFunc) { -#define COMPARISON_FUNC(Num, Val) \ - case llvm::to_underlying(dxbc::ComparisonFunc::Val): -#include "llvm/BinaryFormat/DXContainerConstants.def" - return true; - } - return false; -} - -bool verifyBorderColor(uint32_t BorderColor) { - switch (BorderColor) { -#define STATIC_BORDER_COLOR(Num, Val) \ - case llvm::to_underlying(dxbc::StaticBorderColor::Val): -#include "llvm/BinaryFormat/DXContainerConstants.def" - return true; - } - return false; -} - bool verifyLOD(float LOD) { return !std::isnan(LOD); } bool verifyBoundOffset(uint32_t Offset) { diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp index 1078b1188bb66..73dfa9899d613 100644 --- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp +++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp @@ -335,20 +335,30 @@ Error DXContainerWriter::writeParts(raw_ostream &OS) { } for (const auto &Param : P.RootSignature->samplers()) { - dxbc::RTS0::v1::StaticSampler NewSampler; - NewSampler.Filter = Param.Filter; - NewSampler.AddressU = Param.AddressU; - NewSampler.AddressV = Param.AddressV; - NewSampler.AddressW = Param.AddressW; + assert(dxbc::isValidSamplerFilter(Param.Filter) && + dxbc::isValidAddress(Param.AddressU) && + dxbc::isValidAddress(Param.AddressV) && + dxbc::isValidAddress(Param.AddressW) && + dxbc::isValidComparisonFunc(Param.ComparisonFunc) && + dxbc::isValidBorderColor(Param.BorderColor) && + dxbc::isValidShaderVisibility(Param.ShaderVisibility) && + "Invalid enum value in static sampler"); + + mcdxbc::StaticSampler NewSampler; + NewSampler.Filter = dxbc::SamplerFilter(Param.Filter); + NewSampler.AddressU = dxbc::TextureAddressMode(Param.AddressU); + NewSampler.AddressV = dxbc::TextureAddressMode(Param.AddressV); + NewSampler.AddressW = dxbc::TextureAddressMode(Param.AddressW); NewSampler.MipLODBias = Param.MipLODBias; NewSampler.MaxAnisotropy = Param.MaxAnisotropy; - NewSampler.ComparisonFunc = Param.ComparisonFunc; - NewSampler.BorderColor = Param.BorderColor; + NewSampler.ComparisonFunc = dxbc::ComparisonFunc(Param.ComparisonFunc); + NewSampler.BorderColor = dxbc::StaticBorderColor(Param.BorderColor); NewSampler.MinLOD = Param.MinLOD; NewSampler.MaxLOD = Param.MaxLOD; NewSampler.ShaderRegister = Param.ShaderRegister; NewSampler.RegisterSpace = Param.RegisterSpace; - NewSampler.ShaderVisibility = Param.ShaderVisibility; + NewSampler.ShaderVisibility = + dxbc::ShaderVisibility(Param.ShaderVisibility); RS.StaticSamplers.push_back(NewSampler); } diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp index d02f4b9f7ebcd..5bef33e546e68 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp @@ -203,7 +203,7 @@ static void validateRootSignature(Module &M, } } - for (const dxbc::RTS0::v1::StaticSampler &S : RSD.StaticSamplers) + for (const mcdxbc::StaticSampler &S : RSD.StaticSamplers) Builder.trackBinding(dxil::ResourceClass::Sampler, S.RegisterSpace, S.ShaderRegister, S.ShaderRegister, &S); diff --git a/llvm/test/ObjectYAML/DXContainer/RootSignature-StaticSamplers.yaml b/llvm/test/ObjectYAML/DXContainer/RootSignature-StaticSamplers.yaml index 82d9a4ffdb4f8..888a32b351690 100644 --- a/llvm/test/ObjectYAML/DXContainer/RootSignature-StaticSamplers.yaml +++ b/llvm/test/ObjectYAML/DXContainer/RootSignature-StaticSamplers.yaml @@ -20,7 +20,7 @@ Parts: StaticSamplersOffset: 24 Parameters: [] Samplers: - - Filter: 10 + - Filter: 16 AddressU: 1 AddressV: 2 AddressW: 5 @@ -46,7 +46,7 @@ Parts: #CHECK-NEXT: StaticSamplersOffset: 24 #CHECK-NEXT: Parameters: [] #CHECK-NEXT: Samplers: -#CHECK-NEXT: - Filter: 10 +#CHECK-NEXT: - Filter: 16 #CHECK-NEXT: AddressU: 1 #CHECK-NEXT: AddressV: 2 #CHECK-NEXT: AddressW: 5 diff --git a/llvm/unittests/ObjectYAML/DXContainerYAMLTest.cpp b/llvm/unittests/ObjectYAML/DXContainerYAMLTest.cpp index 4cf8f61e83c8d..a264ca7c3c3f6 100644 --- a/llvm/unittests/ObjectYAML/DXContainerYAMLTest.cpp +++ b/llvm/unittests/ObjectYAML/DXContainerYAMLTest.cpp @@ -492,7 +492,7 @@ TEST(RootSignature, ParseStaticSamplers) { StaticSamplersOffset: 24 Parameters: [] Samplers: - - Filter: 10 + - Filter: 16 AddressU: 1 AddressV: 2 AddressW: 5 @@ -517,7 +517,7 @@ TEST(RootSignature, ParseStaticSamplers) { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x52, 0x54, 0x53, 0x30, 0x4c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0xa4, 0x70, 0x9d, 0x3f, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x85, 0xeb, 0x91, 0x40, 0x66, 0x66, 0x0e, 0x41,