1212// ===----------------------------------------------------------------------===//
1313#include " DXILRootSignature.h"
1414#include " DirectX.h"
15- #include " llvm/ADT/StringRef.h"
1615#include " llvm/ADT/StringSwitch.h"
1716#include " llvm/ADT/Twine.h"
1817#include " llvm/Analysis/DXILMetadataAnalysis.h"
3130#include < cmath>
3231#include < cstdint>
3332#include < optional>
34- #include < string>
3533#include < utility>
3634
3735using namespace llvm ;
@@ -290,32 +288,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
290288 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 1 ))
291289 Range.NumDescriptors = *Val;
292290 else
293- return reportInvalidTypeError<MDString >(Ctx, " RangeDescriptorNode" ,
294- RangeDescriptorNode, 1 );
291+ return reportInvalidTypeError<ConstantInt >(Ctx, " RangeDescriptorNode" ,
292+ RangeDescriptorNode, 1 );
295293
296294 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 2 ))
297295 Range.BaseShaderRegister = *Val;
298296 else
299- return reportInvalidTypeError<MDString >(Ctx, " RangeDescriptorNode" ,
300- RangeDescriptorNode, 2 );
297+ return reportInvalidTypeError<ConstantInt >(Ctx, " RangeDescriptorNode" ,
298+ RangeDescriptorNode, 2 );
301299
302300 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 3 ))
303301 Range.RegisterSpace = *Val;
304302 else
305- return reportInvalidTypeError<MDString >(Ctx, " RangeDescriptorNode" ,
306- RangeDescriptorNode, 3 );
303+ return reportInvalidTypeError<ConstantInt >(Ctx, " RangeDescriptorNode" ,
304+ RangeDescriptorNode, 3 );
307305
308306 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 4 ))
309307 Range.OffsetInDescriptorsFromTableStart = *Val;
310308 else
311- return reportInvalidTypeError<MDString >(Ctx, " RangeDescriptorNode" ,
312- RangeDescriptorNode, 4 );
309+ return reportInvalidTypeError<ConstantInt >(Ctx, " RangeDescriptorNode" ,
310+ RangeDescriptorNode, 4 );
313311
314312 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 5 ))
315313 Range.Flags = *Val;
316314 else
317- return reportInvalidTypeError<MDString >(Ctx, " RangeDescriptorNode" ,
318- RangeDescriptorNode, 5 );
315+ return reportInvalidTypeError<ConstantInt >(Ctx, " RangeDescriptorNode" ,
316+ RangeDescriptorNode, 5 );
319317
320318 Table.Ranges .push_back (Range);
321319 return false ;
@@ -332,8 +330,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
332330 if (std::optional<uint32_t > Val = extractMdIntValue (DescriptorTableNode, 1 ))
333331 Header.ShaderVisibility = *Val;
334332 else
335- return reportInvalidTypeError<MDString >(Ctx, " DescriptorTableNode" ,
336- DescriptorTableNode, 1 );
333+ return reportInvalidTypeError<ConstantInt >(Ctx, " DescriptorTableNode" ,
334+ DescriptorTableNode, 1 );
337335
338336 mcdxbc::DescriptorTable Table;
339337 Header.ParameterType =
@@ -362,67 +360,80 @@ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
362360 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 1 ))
363361 Sampler.Filter = *Val;
364362 else
365- return reportError (Ctx, " Invalid value for Filter" );
363+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
364+ StaticSamplerNode, 1 );
366365
367366 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 2 ))
368367 Sampler.AddressU = *Val;
369368 else
370- return reportError (Ctx, " Invalid value for AddressU" );
369+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
370+ StaticSamplerNode, 2 );
371371
372372 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 3 ))
373373 Sampler.AddressV = *Val;
374374 else
375- return reportError (Ctx, " Invalid value for AddressV" );
375+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
376+ StaticSamplerNode, 3 );
376377
377378 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 4 ))
378379 Sampler.AddressW = *Val;
379380 else
380- return reportError (Ctx, " Invalid value for AddressW" );
381+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
382+ StaticSamplerNode, 4 );
381383
382384 if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 5 ))
383385 Sampler.MipLODBias = Val->convertToFloat ();
384386 else
385- return reportError (Ctx, " Invalid value for MipLODBias" );
387+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
388+ StaticSamplerNode, 5 );
386389
387390 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 6 ))
388391 Sampler.MaxAnisotropy = *Val;
389392 else
390- return reportError (Ctx, " Invalid value for MaxAnisotropy" );
393+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
394+ StaticSamplerNode, 6 );
391395
392396 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 7 ))
393397 Sampler.ComparisonFunc = *Val;
394398 else
395- return reportError (Ctx, " Invalid value for ComparisonFunc " );
399+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
400+ StaticSamplerNode, 7 );
396401
397402 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 8 ))
398403 Sampler.BorderColor = *Val;
399404 else
400- return reportError (Ctx, " Invalid value for ComparisonFunc " );
405+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
406+ StaticSamplerNode, 8 );
401407
402408 if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 9 ))
403409 Sampler.MinLOD = Val->convertToFloat ();
404410 else
405- return reportError (Ctx, " Invalid value for MinLOD" );
411+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
412+ StaticSamplerNode, 9 );
406413
407414 if (std::optional<APFloat> Val = extractMdFloatValue (StaticSamplerNode, 10 ))
408415 Sampler.MaxLOD = Val->convertToFloat ();
409416 else
410- return reportError (Ctx, " Invalid value for MaxLOD" );
417+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
418+ StaticSamplerNode, 10 );
411419
412420 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 11 ))
413421 Sampler.ShaderRegister = *Val;
414422 else
415- return reportError (Ctx, " Invalid value for ShaderRegister" );
423+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
424+ StaticSamplerNode, 11 );
416425
417426 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 12 ))
418427 Sampler.RegisterSpace = *Val;
419428 else
420- return reportError (Ctx, " Invalid value for RegisterSpace" );
429+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
430+ StaticSamplerNode, 12 );
421431
422432 if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 13 ))
423433 Sampler.ShaderVisibility = *Val;
424434 else
425- return reportError (Ctx, " Invalid value for ShaderVisibility" );
435+ return reportInvalidTypeError<ConstantInt>(Ctx, " StaticSamplerNode" ,
436+ StaticSamplerNode, 13 );
426437
427438 RSD.StaticSamplers .push_back (Sampler);
428439 return false ;
0 commit comments