@@ -51,13 +51,13 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
5151 return NodeText->getString ();
5252}
5353
54- static Expected<dxbc::ShaderVisibility>
55- extractShaderVisibility (MDNode *Node, unsigned int OpId) {
54+ template <typename T, std::enable_if_t <std::is_enum_v<T>, int > = 0 >
55+ Expected<T> extractEnumValue (MDNode *Node, unsigned int OpId, StringRef ErrText,
56+ llvm::function_ref<bool (uint32_t )> VerifyFn) {
5657 if (std::optional<uint32_t > Val = extractMdIntValue (Node, OpId)) {
57- if (!dxbc::isValidShaderVisibility (*Val))
58- return make_error<RootSignatureValidationError<uint32_t >>(
59- " ShaderVisibility" , *Val);
60- return dxbc::ShaderVisibility (*Val);
58+ if (!VerifyFn (*Val))
59+ return make_error<RootSignatureValidationError<uint32_t >>(ErrText, *Val);
60+ return static_cast <T>(*Val);
6161 }
6262 return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
6363}
@@ -236,7 +236,9 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
236236 return make_error<InvalidRSMetadataFormat>(" RootConstants Element" );
237237
238238 Expected<dxbc::ShaderVisibility> Visibility =
239- extractShaderVisibility (RootConstantNode, 1 );
239+ extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1 ,
240+ " ShaderVisibility" ,
241+ dxbc::isValidShaderVisibility);
240242 if (auto E = Visibility.takeError ())
241243 return Error (std::move (E));
242244
@@ -290,7 +292,9 @@ Error MetadataParser::parseRootDescriptors(
290292 }
291293
292294 Expected<dxbc::ShaderVisibility> Visibility =
293- extractShaderVisibility (RootDescriptorNode, 1 );
295+ extractEnumValue<dxbc::ShaderVisibility>(RootDescriptorNode, 1 ,
296+ " ShaderVisibility" ,
297+ dxbc::isValidShaderVisibility);
294298 if (auto E = Visibility.takeError ())
295299 return Error (std::move (E));
296300
@@ -377,7 +381,9 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
377381 return make_error<InvalidRSMetadataFormat>(" Descriptor Table" );
378382
379383 Expected<dxbc::ShaderVisibility> Visibility =
380- extractShaderVisibility (DescriptorTableNode, 1 );
384+ extractEnumValue<dxbc::ShaderVisibility>(DescriptorTableNode, 1 ,
385+ " ShaderVisibility" ,
386+ dxbc::isValidShaderVisibility);
381387 if (auto E = Visibility.takeError ())
382388 return Error (std::move (E));
383389
@@ -403,26 +409,34 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
403409 if (StaticSamplerNode->getNumOperands () != 14 )
404410 return make_error<InvalidRSMetadataFormat>(" Static Sampler" );
405411
406- dxbc::RTS0::v1::StaticSampler Sampler;
407- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 1 ))
408- Sampler.Filter = *Val;
409- else
410- return make_error<InvalidRSMetadataValue>(" Filter" );
412+ mcdxbc::StaticSampler Sampler;
411413
412- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 2 ))
413- Sampler.AddressU = *Val;
414- else
415- return make_error<InvalidRSMetadataValue>(" AddressU" );
414+ Expected<dxbc::SamplerFilter> Filter = extractEnumValue<dxbc::SamplerFilter>(
415+ StaticSamplerNode, 1 , " Filter" , dxbc::isValidSamplerFilter);
416+ if (auto E = Filter.takeError ())
417+ return Error (std::move (E));
418+ Sampler.Filter = *Filter;
416419
417- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 3 ))
418- Sampler.AddressV = *Val;
419- else
420- return make_error<InvalidRSMetadataValue>(" AddressV" );
420+ Expected<dxbc::TextureAddressMode> AddressU =
421+ extractEnumValue<dxbc::TextureAddressMode>(
422+ StaticSamplerNode, 2 , " AddressU" , dxbc::isValidAddress);
423+ if (auto E = AddressU.takeError ())
424+ return Error (std::move (E));
425+ Sampler.AddressU = *AddressU;
421426
422- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 4 ))
423- Sampler.AddressW = *Val;
424- else
425- return make_error<InvalidRSMetadataValue>(" AddressW" );
427+ Expected<dxbc::TextureAddressMode> AddressV =
428+ extractEnumValue<dxbc::TextureAddressMode>(
429+ StaticSamplerNode, 3 , " AddressV" , dxbc::isValidAddress);
430+ if (auto E = AddressV.takeError ())
431+ return Error (std::move (E));
432+ Sampler.AddressV = *AddressV;
433+
434+ Expected<dxbc::TextureAddressMode> AddressW =
435+ extractEnumValue<dxbc::TextureAddressMode>(
436+ StaticSamplerNode, 4 , " AddressW" , dxbc::isValidAddress);
437+ if (auto E = AddressW.takeError ())
438+ return Error (std::move (E));
439+ Sampler.AddressW = *AddressW;
426440
427441 if (std::optional<float > Val = extractMdFloatValue (StaticSamplerNode, 5 ))
428442 Sampler.MipLODBias = *Val;
@@ -434,15 +448,19 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
434448 else
435449 return make_error<InvalidRSMetadataValue>(" MaxAnisotropy" );
436450
437- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 7 ))
438- Sampler.ComparisonFunc = *Val;
439- else
440- return make_error<InvalidRSMetadataValue>(" ComparisonFunc" );
451+ Expected<dxbc::ComparisonFunc> ComparisonFunc =
452+ extractEnumValue<dxbc::ComparisonFunc>(
453+ StaticSamplerNode, 7 , " ComparisonFunc" , dxbc::isValidComparisonFunc);
454+ if (auto E = ComparisonFunc.takeError ())
455+ return Error (std::move (E));
456+ Sampler.ComparisonFunc = *ComparisonFunc;
441457
442- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 8 ))
443- Sampler.BorderColor = *Val;
444- else
445- return make_error<InvalidRSMetadataValue>(" ComparisonFunc" );
458+ Expected<dxbc::StaticBorderColor> BorderColor =
459+ extractEnumValue<dxbc::StaticBorderColor>(
460+ StaticSamplerNode, 8 , " BorderColor" , dxbc::isValidBorderColor);
461+ if (auto E = BorderColor.takeError ())
462+ return Error (std::move (E));
463+ Sampler.BorderColor = *BorderColor;
446464
447465 if (std::optional<float > Val = extractMdFloatValue (StaticSamplerNode, 9 ))
448466 Sampler.MinLOD = *Val;
@@ -464,10 +482,13 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
464482 else
465483 return make_error<InvalidRSMetadataValue>(" RegisterSpace" );
466484
467- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 13 ))
468- Sampler.ShaderVisibility = *Val;
469- else
470- return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
485+ Expected<dxbc::ShaderVisibility> Visibility =
486+ extractEnumValue<dxbc::ShaderVisibility>(StaticSamplerNode, 13 ,
487+ " ShaderVisibility" ,
488+ dxbc::isValidShaderVisibility);
489+ if (auto E = Visibility.takeError ())
490+ return Error (std::move (E));
491+ Sampler.ShaderVisibility = *Visibility;
471492
472493 RSD.StaticSamplers .push_back (Sampler);
473494 return Error::success ();
@@ -591,30 +612,7 @@ Error MetadataParser::validateRootSignature(
591612 }
592613 }
593614
594- for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers ) {
595- if (!hlsl::rootsig::verifySamplerFilter (Sampler.Filter ))
596- DeferredErrs =
597- joinErrors (std::move (DeferredErrs),
598- make_error<RootSignatureValidationError<uint32_t >>(
599- " Filter" , Sampler.Filter ));
600-
601- if (!hlsl::rootsig::verifyAddress (Sampler.AddressU ))
602- DeferredErrs =
603- joinErrors (std::move (DeferredErrs),
604- make_error<RootSignatureValidationError<uint32_t >>(
605- " AddressU" , Sampler.AddressU ));
606-
607- if (!hlsl::rootsig::verifyAddress (Sampler.AddressV ))
608- DeferredErrs =
609- joinErrors (std::move (DeferredErrs),
610- make_error<RootSignatureValidationError<uint32_t >>(
611- " AddressV" , Sampler.AddressV ));
612-
613- if (!hlsl::rootsig::verifyAddress (Sampler.AddressW ))
614- DeferredErrs =
615- joinErrors (std::move (DeferredErrs),
616- make_error<RootSignatureValidationError<uint32_t >>(
617- " AddressW" , Sampler.AddressW ));
615+ for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers ) {
618616
619617 if (!hlsl::rootsig::verifyMipLODBias (Sampler.MipLODBias ))
620618 DeferredErrs = joinErrors (std::move (DeferredErrs),
@@ -627,18 +625,6 @@ Error MetadataParser::validateRootSignature(
627625 make_error<RootSignatureValidationError<uint32_t >>(
628626 " MaxAnisotropy" , Sampler.MaxAnisotropy ));
629627
630- if (!hlsl::rootsig::verifyComparisonFunc (Sampler.ComparisonFunc ))
631- DeferredErrs =
632- joinErrors (std::move (DeferredErrs),
633- make_error<RootSignatureValidationError<uint32_t >>(
634- " ComparisonFunc" , Sampler.ComparisonFunc ));
635-
636- if (!hlsl::rootsig::verifyBorderColor (Sampler.BorderColor ))
637- DeferredErrs =
638- joinErrors (std::move (DeferredErrs),
639- make_error<RootSignatureValidationError<uint32_t >>(
640- " BorderColor" , Sampler.BorderColor ));
641-
642628 if (!hlsl::rootsig::verifyLOD (Sampler.MinLOD ))
643629 DeferredErrs = joinErrors (std::move (DeferredErrs),
644630 make_error<RootSignatureValidationError<float >>(
@@ -660,12 +646,6 @@ Error MetadataParser::validateRootSignature(
660646 joinErrors (std::move (DeferredErrs),
661647 make_error<RootSignatureValidationError<uint32_t >>(
662648 " RegisterSpace" , Sampler.RegisterSpace ));
663-
664- if (!dxbc::isValidShaderVisibility (Sampler.ShaderVisibility ))
665- DeferredErrs =
666- joinErrors (std::move (DeferredErrs),
667- make_error<RootSignatureValidationError<uint32_t >>(
668- " ShaderVisibility" , Sampler.ShaderVisibility ));
669649 }
670650
671651 return DeferredErrs;
0 commit comments