2727#include "llvm/Support/Error.h"
2828#include "llvm/Support/ErrorHandling.h"
2929#include "llvm/Support/raw_ostream.h"
30+ #include <cmath>
3031#include <cstdint>
3132#include <optional>
3233#include <utility>
@@ -55,6 +56,13 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
5556 return std::nullopt;
5657}
5758
59+ static std::optional<float> extractMdFloatValue(MDNode *Node,
60+ unsigned int OpId) {
61+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
62+ return CI->getValueAPF().convertToFloat();
63+ return std::nullopt;
64+ }
65+
5866static std::optional<StringRef> extractMdStringValue(MDNode *Node,
5967 unsigned int OpId) {
6068 MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
@@ -261,6 +269,81 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
261269 return false;
262270}
263271
272+ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
273+ MDNode *StaticSamplerNode) {
274+ if (StaticSamplerNode->getNumOperands() != 14)
275+ return reportError(Ctx, "Invalid format for Static Sampler");
276+
277+ dxbc::RTS0::v1::StaticSampler Sampler;
278+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
279+ Sampler.Filter = *Val;
280+ else
281+ return reportError(Ctx, "Invalid value for Filter");
282+
283+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
284+ Sampler.AddressU = *Val;
285+ else
286+ return reportError(Ctx, "Invalid value for AddressU");
287+
288+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
289+ Sampler.AddressV = *Val;
290+ else
291+ return reportError(Ctx, "Invalid value for AddressV");
292+
293+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
294+ Sampler.AddressW = *Val;
295+ else
296+ return reportError(Ctx, "Invalid value for AddressW");
297+
298+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
299+ Sampler.MipLODBias = *Val;
300+ else
301+ return reportError(Ctx, "Invalid value for MipLODBias");
302+
303+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
304+ Sampler.MaxAnisotropy = *Val;
305+ else
306+ return reportError(Ctx, "Invalid value for MaxAnisotropy");
307+
308+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
309+ Sampler.ComparisonFunc = *Val;
310+ else
311+ return reportError(Ctx, "Invalid value for ComparisonFunc ");
312+
313+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
314+ Sampler.BorderColor = *Val;
315+ else
316+ return reportError(Ctx, "Invalid value for ComparisonFunc ");
317+
318+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
319+ Sampler.MinLOD = *Val;
320+ else
321+ return reportError(Ctx, "Invalid value for MinLOD");
322+
323+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
324+ Sampler.MaxLOD = *Val;
325+ else
326+ return reportError(Ctx, "Invalid value for MaxLOD");
327+
328+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
329+ Sampler.ShaderRegister = *Val;
330+ else
331+ return reportError(Ctx, "Invalid value for ShaderRegister");
332+
333+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
334+ Sampler.RegisterSpace = *Val;
335+ else
336+ return reportError(Ctx, "Invalid value for RegisterSpace");
337+
338+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
339+ Sampler.ShaderVisibility = *Val;
340+ else
341+ return reportError(Ctx, "Invalid value for ShaderVisibility");
342+
343+ RSD.StaticSamplers.push_back(Sampler);
344+ return false;
345+ }
346+
264347static bool parseRootSignatureElement(LLVMContext *Ctx,
265348 mcdxbc::RootSignatureDesc &RSD,
266349 MDNode *Element) {
@@ -276,6 +359,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
276359 .Case("RootSRV", RootSignatureElementKind::SRV)
277360 .Case("RootUAV", RootSignatureElementKind::UAV)
278361 .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
362+ .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
279363 .Default(RootSignatureElementKind::Error);
280364
281365 switch (ElementKind) {
@@ -290,6 +374,8 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
290374 return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
291375 case RootSignatureElementKind::DescriptorTable:
292376 return parseDescriptorTable(Ctx, RSD, Element);
377+ case RootSignatureElementKind::StaticSamplers:
378+ return parseStaticSampler(Ctx, RSD, Element);
293379 case RootSignatureElementKind::Error:
294380 return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
295381 }
@@ -406,6 +492,58 @@ static bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
406492 return (Flags & ~Mask) == FlagT::NONE;
407493}
408494
495+ static bool verifySamplerFilter(uint32_t Value) {
496+ switch (Value) {
497+ #define STATIC_SAMPLER_FILTER(Num, Val) \
498+ case llvm::to_underlying(dxbc::StaticSamplerFilter::Val):
499+ #include "llvm/BinaryFormat/DXContainerConstants.def"
500+ return true;
501+ }
502+ return false;
503+ }
504+
505+ // Values allowed here:
506+ // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax
507+ static bool verifyAddress(uint32_t Address) {
508+ switch (Address) {
509+ #define TEXTURE_ADDRESS_MODE(Num, Val) \
510+ case llvm::to_underlying(dxbc::TextureAddressMode::Val):
511+ #include "llvm/BinaryFormat/DXContainerConstants.def"
512+ return true;
513+ }
514+ return false;
515+ }
516+
517+ static bool verifyMipLODBias(float MipLODBias) {
518+ return MipLODBias >= -16.f && MipLODBias <= 15.99f;
519+ }
520+
521+ static bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) {
522+ return MaxAnisotropy <= 16u;
523+ }
524+
525+ static bool verifyComparisonFunc(uint32_t ComparisonFunc) {
526+ switch (ComparisonFunc) {
527+ #define COMPARISON_FUNCTION(Num, Val) \
528+ case llvm::to_underlying(dxbc::SamplersComparisonFunction::Val):
529+ #include "llvm/BinaryFormat/DXContainerConstants.def"
530+ return true;
531+ }
532+ return false;
533+ }
534+
535+ static bool verifyBorderColor(uint32_t BorderColor) {
536+ switch (BorderColor) {
537+ #define STATIC_BORDER_COLOR(Num, Val) \
538+ case llvm::to_underlying(dxbc::SamplersBorderColor::Val):
539+ #include "llvm/BinaryFormat/DXContainerConstants.def"
540+ return true;
541+ }
542+ return false;
543+ }
544+
545+ static bool verifyLOD(float LOD) { return !std::isnan(LOD); }
546+
409547static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
410548
411549 if (!verifyVersion(RSD.Version)) {
@@ -463,6 +601,48 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
463601 }
464602 }
465603
604+ for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
605+ if (!verifySamplerFilter(Sampler.Filter))
606+ return reportValueError(Ctx, "Filter", Sampler.Filter);
607+
608+ if (!verifyAddress(Sampler.AddressU))
609+ return reportValueError(Ctx, "AddressU", Sampler.AddressU);
610+
611+ if (!verifyAddress(Sampler.AddressV))
612+ return reportValueError(Ctx, "AddressV", Sampler.AddressV);
613+
614+ if (!verifyAddress(Sampler.AddressW))
615+ return reportValueError(Ctx, "AddressW", Sampler.AddressW);
616+
617+ if (!verifyMipLODBias(Sampler.MipLODBias))
618+ return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
619+
620+ if (!verifyMaxAnisotropy(Sampler.MaxAnisotropy))
621+ return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
622+
623+ if (!verifyComparisonFunc(Sampler.ComparisonFunc))
624+ return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
625+
626+ if (!verifyBorderColor(Sampler.BorderColor))
627+ return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
628+
629+ if (!verifyLOD(Sampler.MinLOD))
630+ return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
631+
632+ if (!verifyLOD(Sampler.MaxLOD))
633+ return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
634+
635+ if (!verifyRegisterValue(Sampler.ShaderRegister))
636+ return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
637+
638+ if (!verifyRegisterSpace(Sampler.RegisterSpace))
639+ return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
640+
641+ if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
642+ return reportValueError(Ctx, "ShaderVisibility",
643+ Sampler.ShaderVisibility);
644+ }
645+
466646 return false;
467647}
468648
@@ -542,6 +722,9 @@ analyzeModule(Module &M) {
542722 // offset will always equal to the header size.
543723 RSD.RootParameterOffset = sizeof(dxbc::RTS0::v1::RootSignatureHeader);
544724
725+ // static sampler offset is calculated when writting dxcontainer.
726+ RSD.StaticSamplersOffset = 0u;
727+
545728 if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
546729 return RSDMap;
547730 }
0 commit comments