Skip to content

Commit 58ce3e2

Browse files
authored
[DirectX] Fix Flags validation to prevent casting into enum (#161587)
This PR changes the validation logic for Root Descriptor and Descriptor Range flags to properly check if the `uint32_t` values are within range before casting into the enums.
1 parent 74af578 commit 58ce3e2

File tree

8 files changed

+91
-21
lines changed

8 files changed

+91
-21
lines changed

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,8 +1289,8 @@ bool SemaHLSL::handleRootSignatureElements(
12891289
VerifyRegister(Loc, Descriptor->Reg.Number);
12901290
VerifySpace(Loc, Descriptor->Space);
12911291

1292-
if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(
1293-
Version, llvm::to_underlying(Descriptor->Flags)))
1292+
if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(Version,
1293+
Descriptor->Flags))
12941294
ReportFlagError(Loc);
12951295
} else if (const auto *Constants =
12961296
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {

llvm/include/llvm/BinaryFormat/DXContainer.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,12 @@ enum class StaticBorderColor : uint32_t {
248248

249249
bool isValidBorderColor(uint32_t V);
250250

251+
bool isValidRootDesciptorFlags(uint32_t V);
252+
253+
bool isValidDescriptorRangeFlags(uint32_t V);
254+
255+
bool isValidStaticSamplerFlags(uint32_t V);
256+
251257
LLVM_ABI ArrayRef<EnumEntry<StaticBorderColor>> getStaticBorderColors();
252258

253259
LLVM_ABI PartType parsePartType(StringRef S);

llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ LLVM_ABI bool verifyRootFlag(uint32_t Flags);
2828
LLVM_ABI bool verifyVersion(uint32_t Version);
2929
LLVM_ABI bool verifyRegisterValue(uint32_t RegisterValue);
3030
LLVM_ABI bool verifyRegisterSpace(uint32_t RegisterSpace);
31-
LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal);
31+
LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version,
32+
dxbc::RootDescriptorFlags Flags);
3233
LLVM_ABI bool verifyRangeType(uint32_t Type);
3334
LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version,
3435
dxil::ResourceClass Type,
35-
dxbc::DescriptorRangeFlags FlagsVal);
36-
LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber);
36+
dxbc::DescriptorRangeFlags Flags);
37+
LLVM_ABI bool verifyStaticSamplerFlags(uint32_t Version,
38+
dxbc::StaticSamplerFlags Flags);
3739
LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors);
3840
LLVM_ABI bool verifyMipLODBias(float MipLODBias);
3941
LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy);

llvm/lib/BinaryFormat/DXContainer.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,27 @@ bool llvm::dxbc::isValidBorderColor(uint32_t V) {
8282
return false;
8383
}
8484

85+
bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) {
86+
using FlagT = dxbc::RootDescriptorFlags;
87+
uint32_t LargestValue =
88+
llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
89+
return V < NextPowerOf2(LargestValue);
90+
}
91+
92+
bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) {
93+
using FlagT = dxbc::DescriptorRangeFlags;
94+
uint32_t LargestValue =
95+
llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
96+
return V < NextPowerOf2(LargestValue);
97+
}
98+
99+
bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) {
100+
using FlagT = dxbc::StaticSamplerFlags;
101+
uint32_t LargestValue =
102+
llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR);
103+
return V < NextPowerOf2(LargestValue);
104+
}
105+
85106
dxbc::PartType dxbc::parsePartType(StringRef S) {
86107
#define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName)
87108
return StringSwitch<dxbc::PartType>(S)

llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature(
651651
"RegisterSpace", Descriptor.RegisterSpace));
652652

653653
if (RSD.Version > 1) {
654-
if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
655-
Descriptor.Flags))
654+
bool IsValidFlag =
655+
dxbc::isValidRootDesciptorFlags(Descriptor.Flags) &&
656+
hlsl::rootsig::verifyRootDescriptorFlag(
657+
RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags));
658+
if (!IsValidFlag)
656659
DeferredErrs =
657660
joinErrors(std::move(DeferredErrs),
658661
make_error<RootSignatureValidationError<uint32_t>>(
@@ -676,9 +679,11 @@ Error MetadataParser::validateRootSignature(
676679
make_error<RootSignatureValidationError<uint32_t>>(
677680
"NumDescriptors", Range.NumDescriptors));
678681

679-
if (!hlsl::rootsig::verifyDescriptorRangeFlag(
680-
RSD.Version, Range.RangeType,
681-
dxbc::DescriptorRangeFlags(Range.Flags)))
682+
bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) &&
683+
hlsl::rootsig::verifyDescriptorRangeFlag(
684+
RSD.Version, Range.RangeType,
685+
dxbc::DescriptorRangeFlags(Range.Flags));
686+
if (!IsValidFlag)
682687
DeferredErrs =
683688
joinErrors(std::move(DeferredErrs),
684689
make_error<RootSignatureValidationError<uint32_t>>(
@@ -731,8 +736,11 @@ Error MetadataParser::validateRootSignature(
731736
joinErrors(std::move(DeferredErrs),
732737
make_error<RootSignatureValidationError<uint32_t>>(
733738
"RegisterSpace", Sampler.RegisterSpace));
734-
735-
if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags))
739+
bool IsValidFlag =
740+
dxbc::isValidStaticSamplerFlags(Sampler.Flags) &&
741+
hlsl::rootsig::verifyStaticSamplerFlags(
742+
RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags));
743+
if (!IsValidFlag)
736744
DeferredErrs =
737745
joinErrors(std::move(DeferredErrs),
738746
make_error<RootSignatureValidationError<uint32_t>>(

llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) {
3434
return !(RegisterSpace >= 0xFFFFFFF0);
3535
}
3636

37-
bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
37+
bool verifyRootDescriptorFlag(uint32_t Version,
38+
dxbc::RootDescriptorFlags FlagsVal) {
3839
using FlagT = dxbc::RootDescriptorFlags;
3940
FlagT Flags = FlagT(FlagsVal);
4041
if (Version == 1)
@@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
5657
bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
5758
dxbc::DescriptorRangeFlags Flags) {
5859
using FlagT = dxbc::DescriptorRangeFlags;
59-
6060
const bool IsSampler = (Type == dxil::ResourceClass::Sampler);
6161

6262
if (Version == 1) {
@@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type,
113113
return (Flags & ~Mask) == FlagT::None;
114114
}
115115

116-
bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) {
117-
uint32_t LargestValue = llvm::to_underlying(
118-
dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR);
119-
if (FlagsNumber >= NextPowerOf2(LargestValue))
120-
return false;
121-
122-
dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber);
116+
bool verifyStaticSamplerFlags(uint32_t Version,
117+
dxbc::StaticSamplerFlags Flags) {
123118
if (Version <= 2)
124119
return Flags == dxbc::StaticSamplerFlags::None;
125120

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
2+
3+
target triple = "dxil-unknown-shadermodel6.0-compute"
4+
5+
; CHECK: error: Invalid value for DescriptorFlag: 66666
6+
; CHECK-NOT: Root Signature Definitions
7+
8+
define void @main() #0 {
9+
entry:
10+
ret void
11+
}
12+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
13+
14+
15+
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
16+
!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
17+
!3 = !{ !5 } ; list of root signature elements
18+
!5 = !{ !"DescriptorTable", i32 0, !6, !7 }
19+
!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 66666 }
20+
!7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
2+
3+
target triple = "dxil-unknown-shadermodel6.0-compute"
4+
5+
6+
; CHECK: error: Invalid value for RootDescriptorFlag: 666
7+
; CHECK-NOT: Root Signature Definitions
8+
define void @main() #0 {
9+
entry:
10+
ret void
11+
}
12+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
13+
14+
15+
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
16+
!2 = !{ ptr @main, !3, i32 2 } ; function, root signature
17+
!3 = !{ !5 } ; list of root signature elements
18+
!5 = !{ !"RootCBV", i32 0, i32 1, i32 2, i32 666 }

0 commit comments

Comments
 (0)