12
12
// ===----------------------------------------------------------------------===//
13
13
#include " DXILRootSignature.h"
14
14
#include " DirectX.h"
15
+ #include " llvm/ADT/APFloat.h"
15
16
#include " llvm/ADT/StringSwitch.h"
16
17
#include " llvm/ADT/Twine.h"
17
18
#include " llvm/Analysis/DXILMetadataAnalysis.h"
@@ -55,6 +56,13 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
55
56
return std::nullopt;
56
57
}
57
58
59
+ static std::optional<APFloat> extractMdFloatValue (MDNode *Node,
60
+ unsigned int OpId) {
61
+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand (OpId).get ()))
62
+ return CI->getValue ();
63
+ return std::nullopt;
64
+ }
65
+
58
66
static std::optional<StringRef> extractMdStringValue (MDNode *Node,
59
67
unsigned int OpId) {
60
68
MDString *NodeText = dyn_cast<MDString>(Node->getOperand (OpId));
@@ -262,6 +270,81 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
262
270
return false ;
263
271
}
264
272
273
+ static bool parseStaticSampler (LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
274
+ MDNode *StaticSamplerNode) {
275
+ if (StaticSamplerNode->getNumOperands () != 14 )
276
+ return reportError (Ctx, " Invalid format for Static Sampler" );
277
+
278
+ dxbc::RTS0::v1::StaticSampler Sampler;
279
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 1 ))
280
+ Sampler.Filter = *Val;
281
+ else
282
+ return reportError (Ctx, " Invalid value for Filter" );
283
+
284
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 2 ))
285
+ Sampler.AddressU = *Val;
286
+ else
287
+ return reportError (Ctx, " Invalid value for AddressU" );
288
+
289
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 3 ))
290
+ Sampler.AddressV = *Val;
291
+ else
292
+ return reportError (Ctx, " Invalid value for AddressV" );
293
+
294
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 4 ))
295
+ Sampler.AddressW = *Val;
296
+ else
297
+ return reportError (Ctx, " Invalid value for AddressW" );
298
+
299
+ if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 5 ))
300
+ Sampler.MipLODBias = Val->convertToFloat ();
301
+ else
302
+ return reportError (Ctx, " Invalid value for MipLODBias" );
303
+
304
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 6 ))
305
+ Sampler.MaxAnisotropy = *Val;
306
+ else
307
+ return reportError (Ctx, " Invalid value for MaxAnisotropy" );
308
+
309
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 7 ))
310
+ Sampler.ComparisonFunc = *Val;
311
+ else
312
+ return reportError (Ctx, " Invalid value for ComparisonFunc " );
313
+
314
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 8 ))
315
+ Sampler.BorderColor = *Val;
316
+ else
317
+ return reportError (Ctx, " Invalid value for ComparisonFunc " );
318
+
319
+ if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 9 ))
320
+ Sampler.MinLOD = Val->convertToFloat ();
321
+ else
322
+ return reportError (Ctx, " Invalid value for MinLOD" );
323
+
324
+ if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 10 ))
325
+ Sampler.MaxLOD = Val->convertToFloat ();
326
+ else
327
+ return reportError (Ctx, " Invalid value for MaxLOD" );
328
+
329
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 11 ))
330
+ Sampler.ShaderRegister = *Val;
331
+ else
332
+ return reportError (Ctx, " Invalid value for ShaderRegister" );
333
+
334
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 12 ))
335
+ Sampler.RegisterSpace = *Val;
336
+ else
337
+ return reportError (Ctx, " Invalid value for RegisterSpace" );
338
+
339
+ if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 13 ))
340
+ Sampler.ShaderVisibility = *Val;
341
+ else
342
+ return reportError (Ctx, " Invalid value for ShaderVisibility" );
343
+
344
+ RSD.StaticSamplers .push_back (Sampler);
345
+ return false ;
346
+ }
347
+
265
348
static bool parseRootSignatureElement (LLVMContext *Ctx,
266
349
mcdxbc::RootSignatureDesc &RSD,
267
350
MDNode *Element) {
@@ -277,6 +360,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
277
360
.Case (" RootSRV" , RootSignatureElementKind::SRV)
278
361
.Case (" RootUAV" , RootSignatureElementKind::UAV)
279
362
.Case (" DescriptorTable" , RootSignatureElementKind::DescriptorTable)
363
+ .Case (" StaticSampler" , RootSignatureElementKind::StaticSamplers)
280
364
.Default (RootSignatureElementKind::Error);
281
365
282
366
switch (ElementKind) {
@@ -291,6 +375,8 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
291
375
return parseRootDescriptors (Ctx, RSD, Element, ElementKind);
292
376
case RootSignatureElementKind::DescriptorTable:
293
377
return parseDescriptorTable (Ctx, RSD, Element);
378
+ case RootSignatureElementKind::StaticSamplers:
379
+ return parseStaticSampler (Ctx, RSD, Element);
294
380
case RootSignatureElementKind::Error:
295
381
return reportError (Ctx, " Invalid Root Signature Element: " + *ElementText);
296
382
}
@@ -522,6 +608,9 @@ analyzeModule(Module &M) {
522
608
// offset will always equal to the header size.
523
609
RSD.RootParameterOffset = sizeof (dxbc::RTS0::v1::RootSignatureHeader);
524
610
611
+ // static sampler offset is calculated when writting dxcontainer.
612
+ RSD.StaticSamplersOffset = 0u ;
613
+
525
614
if (parse (Ctx, RSD, RootElementListNode) || validate (Ctx, RSD)) {
526
615
return RSDMap;
527
616
}
0 commit comments