1212// ===----------------------------------------------------------------------===//
1313#include " DXILRootSignature.h"
1414#include " DirectX.h"
15+ #include " llvm/ADT/APFloat.h"
1516#include " llvm/ADT/StringSwitch.h"
1617#include " llvm/ADT/Twine.h"
1718#include " llvm/Analysis/DXILMetadataAnalysis.h"
@@ -55,6 +56,13 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
5556 return std::nullopt ;
5657}
5758
59+ static std::optional<APFloat> extractMdFloatValue (MDNode *Node,
60+ unsigned int OpId) {
61+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand (OpId).get ()))
62+ return CI->getValue ();
63+ return std::nullopt ;
64+ }
65+
5866static std::optional<StringRef> extractMdStringValue (MDNode *Node,
5967 unsigned int OpId) {
6068 MDString *NodeText = dyn_cast<MDString>(Node->getOperand (OpId));
@@ -262,6 +270,81 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
262270 return false ;
263271}
264272
273+ static bool parseStaticSampler (LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
274+ MDNode *StaticSamplerNode) {
275+ if (StaticSamplerNode->getNumOperands () != 14 )
276+ return reportError (Ctx, " Invalid format for Static Sampler" );
277+
278+ dxbc::RTS0::v1::StaticSampler Sampler;
279+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 1 ))
280+ Sampler.Filter = *Val;
281+ else
282+ return reportError (Ctx, " Invalid value for Filter" );
283+
284+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 2 ))
285+ Sampler.AddressU = *Val;
286+ else
287+ return reportError (Ctx, " Invalid value for AddressU" );
288+
289+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 3 ))
290+ Sampler.AddressV = *Val;
291+ else
292+ return reportError (Ctx, " Invalid value for AddressV" );
293+
294+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 4 ))
295+ Sampler.AddressW = *Val;
296+ else
297+ return reportError (Ctx, " Invalid value for AddressW" );
298+
299+ if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 5 ))
300+ Sampler.MipLODBias = Val->convertToFloat ();
301+ else
302+ return reportError (Ctx, " Invalid value for MipLODBias" );
303+
304+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 6 ))
305+ Sampler.MaxAnisotropy = *Val;
306+ else
307+ return reportError (Ctx, " Invalid value for MaxAnisotropy" );
308+
309+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 7 ))
310+ Sampler.ComparisonFunc = *Val;
311+ else
312+ return reportError (Ctx, " Invalid value for ComparisonFunc " );
313+
314+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 8 ))
315+ Sampler.BorderColor = *Val;
316+ else
317+ return reportError (Ctx, " Invalid value for ComparisonFunc " );
318+
319+ if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 9 ))
320+ Sampler.MinLOD = Val->convertToFloat ();
321+ else
322+ return reportError (Ctx, " Invalid value for MinLOD" );
323+
324+ if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 10 ))
325+ Sampler.MaxLOD = Val->convertToFloat ();
326+ else
327+ return reportError (Ctx, " Invalid value for MaxLOD" );
328+
329+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 11 ))
330+ Sampler.ShaderRegister = *Val;
331+ else
332+ return reportError (Ctx, " Invalid value for ShaderRegister" );
333+
334+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 12 ))
335+ Sampler.RegisterSpace = *Val;
336+ else
337+ return reportError (Ctx, " Invalid value for RegisterSpace" );
338+
339+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 13 ))
340+ Sampler.ShaderVisibility = *Val;
341+ else
342+ return reportError (Ctx, " Invalid value for ShaderVisibility" );
343+
344+ RSD.StaticSamplers .push_back (Sampler);
345+ return false ;
346+ }
347+
265348static bool parseRootSignatureElement (LLVMContext *Ctx,
266349 mcdxbc::RootSignatureDesc &RSD,
267350 MDNode *Element) {
@@ -277,6 +360,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
277360 .Case (" RootSRV" , RootSignatureElementKind::SRV)
278361 .Case (" RootUAV" , RootSignatureElementKind::UAV)
279362 .Case (" DescriptorTable" , RootSignatureElementKind::DescriptorTable)
363+ .Case (" StaticSampler" , RootSignatureElementKind::StaticSamplers)
280364 .Default (RootSignatureElementKind::Error);
281365
282366 switch (ElementKind) {
@@ -291,6 +375,8 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
291375 return parseRootDescriptors (Ctx, RSD, Element, ElementKind);
292376 case RootSignatureElementKind::DescriptorTable:
293377 return parseDescriptorTable (Ctx, RSD, Element);
378+ case RootSignatureElementKind::StaticSamplers:
379+ return parseStaticSampler (Ctx, RSD, Element);
294380 case RootSignatureElementKind::Error:
295381 return reportError (Ctx, " Invalid Root Signature Element: " + *ElementText);
296382 }
@@ -522,6 +608,9 @@ analyzeModule(Module &M) {
522608 // offset will always equal to the header size.
523609 RSD.RootParameterOffset = sizeof (dxbc::RTS0::v1::RootSignatureHeader);
524610
611+ // static sampler offset is calculated when writting dxcontainer.
612+ RSD.StaticSamplersOffset = 0u ;
613+
525614 if (parse (Ctx, RSD, RootElementListNode) || validate (Ctx, RSD)) {
526615 return RSDMap;
527616 }
0 commit comments