@@ -52,13 +52,15 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
5252 return NodeText->getString ();
5353}
5454
55- static Expected<dxbc::ShaderVisibility>
56- extractShaderVisibility (MDNode *Node, unsigned int OpId) {
55+ template <typename T, typename = std::enable_if_t <
56+ std::is_enum_v<T> &&
57+ std::is_same_v<std::underlying_type_t <T>, uint32_t >>>
58+ Expected<T> extractEnumValue (MDNode *Node, unsigned int OpId, StringRef ErrText,
59+ llvm::function_ref<bool (uint32_t )> VerifyFn) {
5760 if (std::optional<uint32_t > Val = extractMdIntValue (Node, OpId)) {
58- if (!dxbc::isValidShaderVisibility (*Val))
59- return make_error<RootSignatureValidationError<uint32_t >>(
60- " ShaderVisibility" , *Val);
61- return dxbc::ShaderVisibility (*Val);
61+ if (!VerifyFn (*Val))
62+ return make_error<RootSignatureValidationError<uint32_t >>(ErrText, *Val);
63+ return static_cast <T>(*Val);
6264 }
6365 return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
6466}
@@ -233,7 +235,9 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
233235 return make_error<InvalidRSMetadataFormat>(" RootConstants Element" );
234236
235237 Expected<dxbc::ShaderVisibility> Visibility =
236- extractShaderVisibility (RootConstantNode, 1 );
238+ extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1 ,
239+ " ShaderVisibility" ,
240+ dxbc::isValidShaderVisibility);
237241 if (auto E = Visibility.takeError ())
238242 return Error (std::move (E));
239243
@@ -287,7 +291,9 @@ Error MetadataParser::parseRootDescriptors(
287291 }
288292
289293 Expected<dxbc::ShaderVisibility> Visibility =
290- extractShaderVisibility (RootDescriptorNode, 1 );
294+ extractEnumValue<dxbc::ShaderVisibility>(RootDescriptorNode, 1 ,
295+ " ShaderVisibility" ,
296+ dxbc::isValidShaderVisibility);
291297 if (auto E = Visibility.takeError ())
292298 return Error (std::move (E));
293299
@@ -380,7 +386,9 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
380386 return make_error<InvalidRSMetadataFormat>(" Descriptor Table" );
381387
382388 Expected<dxbc::ShaderVisibility> Visibility =
383- extractShaderVisibility (DescriptorTableNode, 1 );
389+ extractEnumValue<dxbc::ShaderVisibility>(DescriptorTableNode, 1 ,
390+ " ShaderVisibility" ,
391+ dxbc::isValidShaderVisibility);
384392 if (auto E = Visibility.takeError ())
385393 return Error (std::move (E));
386394
@@ -406,26 +414,34 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
406414 if (StaticSamplerNode->getNumOperands () != 14 )
407415 return make_error<InvalidRSMetadataFormat>(" Static Sampler" );
408416
409- dxbc::RTS0::v1::StaticSampler Sampler;
410- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 1 ))
411- Sampler.Filter = *Val;
412- else
413- return make_error<InvalidRSMetadataValue>(" Filter" );
417+ mcdxbc::StaticSampler Sampler;
414418
415- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 2 ))
416- Sampler.AddressU = *Val;
417- else
418- return make_error<InvalidRSMetadataValue>(" AddressU" );
419+ Expected<dxbc::SamplerFilter> Filter = extractEnumValue<dxbc::SamplerFilter>(
420+ StaticSamplerNode, 1 , " Filter" , dxbc::isValidSamplerFilter);
421+ if (auto E = Filter.takeError ())
422+ return Error (std::move (E));
423+ Sampler.Filter = *Filter;
419424
420- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 3 ))
421- Sampler.AddressV = *Val;
422- else
423- return make_error<InvalidRSMetadataValue>(" AddressV" );
425+ Expected<dxbc::TextureAddressMode> AddressU =
426+ extractEnumValue<dxbc::TextureAddressMode>(
427+ StaticSamplerNode, 2 , " AddressU" , dxbc::isValidAddress);
428+ if (auto E = AddressU.takeError ())
429+ return Error (std::move (E));
430+ Sampler.AddressU = *AddressU;
424431
425- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 4 ))
426- Sampler.AddressW = *Val;
427- else
428- return make_error<InvalidRSMetadataValue>(" AddressW" );
432+ Expected<dxbc::TextureAddressMode> AddressV =
433+ extractEnumValue<dxbc::TextureAddressMode>(
434+ StaticSamplerNode, 3 , " AddressV" , dxbc::isValidAddress);
435+ if (auto E = AddressV.takeError ())
436+ return Error (std::move (E));
437+ Sampler.AddressV = *AddressV;
438+
439+ Expected<dxbc::TextureAddressMode> AddressW =
440+ extractEnumValue<dxbc::TextureAddressMode>(
441+ StaticSamplerNode, 4 , " AddressW" , dxbc::isValidAddress);
442+ if (auto E = AddressW.takeError ())
443+ return Error (std::move (E));
444+ Sampler.AddressW = *AddressW;
429445
430446 if (std::optional<float > Val = extractMdFloatValue (StaticSamplerNode, 5 ))
431447 Sampler.MipLODBias = *Val;
@@ -437,15 +453,19 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
437453 else
438454 return make_error<InvalidRSMetadataValue>(" MaxAnisotropy" );
439455
440- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 7 ))
441- Sampler.ComparisonFunc = *Val;
442- else
443- return make_error<InvalidRSMetadataValue>(" ComparisonFunc" );
456+ Expected<dxbc::ComparisonFunc> ComparisonFunc =
457+ extractEnumValue<dxbc::ComparisonFunc>(
458+ StaticSamplerNode, 7 , " ComparisonFunc" , dxbc::isValidComparisonFunc);
459+ if (auto E = ComparisonFunc.takeError ())
460+ return Error (std::move (E));
461+ Sampler.ComparisonFunc = *ComparisonFunc;
444462
445- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 8 ))
446- Sampler.BorderColor = *Val;
447- else
448- return make_error<InvalidRSMetadataValue>(" ComparisonFunc" );
463+ Expected<dxbc::StaticBorderColor> BorderColor =
464+ extractEnumValue<dxbc::StaticBorderColor>(
465+ StaticSamplerNode, 8 , " BorderColor" , dxbc::isValidBorderColor);
466+ if (auto E = BorderColor.takeError ())
467+ return Error (std::move (E));
468+ Sampler.BorderColor = *BorderColor;
449469
450470 if (std::optional<float > Val = extractMdFloatValue (StaticSamplerNode, 9 ))
451471 Sampler.MinLOD = *Val;
@@ -467,10 +487,13 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
467487 else
468488 return make_error<InvalidRSMetadataValue>(" RegisterSpace" );
469489
470- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 13 ))
471- Sampler.ShaderVisibility = *Val;
472- else
473- return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
490+ Expected<dxbc::ShaderVisibility> Visibility =
491+ extractEnumValue<dxbc::ShaderVisibility>(StaticSamplerNode, 13 ,
492+ " ShaderVisibility" ,
493+ dxbc::isValidShaderVisibility);
494+ if (auto E = Visibility.takeError ())
495+ return Error (std::move (E));
496+ Sampler.ShaderVisibility = *Visibility;
474497
475498 RSD.StaticSamplers .push_back (Sampler);
476499 return Error::success ();
@@ -594,30 +617,7 @@ Error MetadataParser::validateRootSignature(
594617 }
595618 }
596619
597- for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers ) {
598- if (!hlsl::rootsig::verifySamplerFilter (Sampler.Filter ))
599- DeferredErrs =
600- joinErrors (std::move (DeferredErrs),
601- make_error<RootSignatureValidationError<uint32_t >>(
602- " Filter" , Sampler.Filter ));
603-
604- if (!hlsl::rootsig::verifyAddress (Sampler.AddressU ))
605- DeferredErrs =
606- joinErrors (std::move (DeferredErrs),
607- make_error<RootSignatureValidationError<uint32_t >>(
608- " AddressU" , Sampler.AddressU ));
609-
610- if (!hlsl::rootsig::verifyAddress (Sampler.AddressV ))
611- DeferredErrs =
612- joinErrors (std::move (DeferredErrs),
613- make_error<RootSignatureValidationError<uint32_t >>(
614- " AddressV" , Sampler.AddressV ));
615-
616- if (!hlsl::rootsig::verifyAddress (Sampler.AddressW ))
617- DeferredErrs =
618- joinErrors (std::move (DeferredErrs),
619- make_error<RootSignatureValidationError<uint32_t >>(
620- " AddressW" , Sampler.AddressW ));
620+ for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers ) {
621621
622622 if (!hlsl::rootsig::verifyMipLODBias (Sampler.MipLODBias ))
623623 DeferredErrs = joinErrors (std::move (DeferredErrs),
@@ -630,18 +630,6 @@ Error MetadataParser::validateRootSignature(
630630 make_error<RootSignatureValidationError<uint32_t >>(
631631 " MaxAnisotropy" , Sampler.MaxAnisotropy ));
632632
633- if (!hlsl::rootsig::verifyComparisonFunc (Sampler.ComparisonFunc ))
634- DeferredErrs =
635- joinErrors (std::move (DeferredErrs),
636- make_error<RootSignatureValidationError<uint32_t >>(
637- " ComparisonFunc" , Sampler.ComparisonFunc ));
638-
639- if (!hlsl::rootsig::verifyBorderColor (Sampler.BorderColor ))
640- DeferredErrs =
641- joinErrors (std::move (DeferredErrs),
642- make_error<RootSignatureValidationError<uint32_t >>(
643- " BorderColor" , Sampler.BorderColor ));
644-
645633 if (!hlsl::rootsig::verifyLOD (Sampler.MinLOD ))
646634 DeferredErrs = joinErrors (std::move (DeferredErrs),
647635 make_error<RootSignatureValidationError<float >>(
@@ -663,12 +651,6 @@ Error MetadataParser::validateRootSignature(
663651 joinErrors (std::move (DeferredErrs),
664652 make_error<RootSignatureValidationError<uint32_t >>(
665653 " RegisterSpace" , Sampler.RegisterSpace ));
666-
667- if (!dxbc::isValidShaderVisibility (Sampler.ShaderVisibility ))
668- DeferredErrs =
669- joinErrors (std::move (DeferredErrs),
670- make_error<RootSignatureValidationError<uint32_t >>(
671- " ShaderVisibility" , Sampler.ShaderVisibility ));
672654 }
673655
674656 return DeferredErrs;
0 commit comments