Skip to content
17 changes: 12 additions & 5 deletions llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
Expand Down Expand Up @@ -559,11 +560,17 @@ bool MetadataParser::validateRootSignature(
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");

switch (Info.Header.ParameterType) {
dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);

case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
switch (PT) {
case dxbc::RootParameterType::Constants32Bit:
// ToDo: Add proper validation.
continue;

case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
Expand All @@ -580,7 +587,7 @@ bool MetadataParser::validateRootSignature(
}
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
Expand Down
7 changes: 3 additions & 4 deletions llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {

bool verifyRangeType(uint32_t Type) {
switch (Type) {
case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
#define DESCRIPTOR_RANGE(Num, Val) \
case llvm::to_underlying(dxbc::DescriptorRangeType::Val):
#include "llvm/BinaryFormat/DXContainerConstants.def"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not updating to not use to_underlying

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This specific case is necessary, to keep using llvm::to_underlying, since that is checking if an uint_32t is valid value for RootParametersType. However, I updated to use the tablegen definition, that way we will always have it covering all possible values.

Or we can change how the check is being done, if folks prefer.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to me like we get the same end result by doing the int->enum cast once before the switch and having the switch statement operate on the actual enumerations. I have a strong preference for not using to_underlying since it impairs the frontend's ability to generate warnings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this point, we are not sure Type is a valid DescriptorRangeType so the casting here int->enum would cause undefined behaviour.

return true;
};

Expand Down
35 changes: 23 additions & 12 deletions llvm/lib/MC/DXContainerRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Support/EndianStream.h"

using namespace llvm;
Expand Down Expand Up @@ -35,20 +36,26 @@ size_t RootSignatureDesc::getSize() const {
StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);

for (const RootParameterInfo &I : ParametersContainer) {
switch (I.Header.ParameterType) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
if (!dxbc::isValidParameterType(I.Header.ParameterType))
continue;

dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(I.Header.ParameterType);

switch (PT) {
case dxbc::RootParameterType::Constants32Bit:
Size += sizeof(dxbc::RTS0::v1::RootConstants);
break;
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV:
if (Version == 1)
Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
else
Size += sizeof(dxbc::RTS0::v2::RootDescriptor);

break;
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
case dxbc::RootParameterType::DescriptorTable:
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(I.Location);

Expand Down Expand Up @@ -97,8 +104,12 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
for (size_t I = 0; I < NumParameters; ++I) {
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
if (!dxbc::isValidParameterType(Type))
continue;
dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);

switch (PT) {
case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
ParametersContainer.getConstant(Loc);
support::endian::write(BOS, Constants.ShaderRegister,
Expand All @@ -109,9 +120,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
ParametersContainer.getRootDescriptor(Loc);

Expand All @@ -123,7 +134,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(Loc);
support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
Expand Down
28 changes: 17 additions & 11 deletions llvm/lib/ObjectYAML/DXContainerEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,19 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
L.Header.Offset};

switch (L.Header.Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
if (!dxbc::isValidParameterType(L.Header.Type)) {
// Handling invalid parameter type edge case. We intentionally let
// obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
// for that to be used as a testing tool more effectively.
RS.ParametersContainer.addInvalidParameter(Header);
continue;
}

dxbc::RootParameterType ParameterType =
static_cast<dxbc::RootParameterType>(L.Header.Type);

switch (ParameterType) {
case dxbc::RootParameterType::Constants32Bit: {
const DXContainerYAML::RootConstantsYaml &ConstantYaml =
P.RootSignature->Parameters.getOrInsertConstants(L);
dxbc::RTS0::v1::RootConstants Constants;
Expand All @@ -289,9 +300,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Constants);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
P.RootSignature->Parameters.getOrInsertDescriptor(L);

Expand All @@ -303,7 +314,7 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Descriptor);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const DXContainerYAML::DescriptorTableYaml &TableYaml =
P.RootSignature->Parameters.getOrInsertTable(L);
mcdxbc::DescriptorTable Table;
Expand All @@ -323,11 +334,6 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Table);
break;
}
default:
// Handling invalid parameter type edge case. We intentionally let
// obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
// for that to be used as a testing tool more effectively.
RS.ParametersContainer.addInvalidParameter(Header);
}
}

Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/ObjectYAML/DXContainerYAML.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,22 +424,28 @@ void MappingContextTraits<DXContainerYAML::RootParameterLocationYaml,
IO.mapRequired("ParameterType", L.Header.Type);
IO.mapRequired("ShaderVisibility", L.Header.Visibility);

switch (L.Header.Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
if (!dxbc::isValidParameterType(L.Header.Type))
return;
dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(L.Header.Type);

// We allow ParameterType to be invalid here.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment doesn't match the implementation. It returns above if the parameter type is invalid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the comments to make it clearer what the current behaviour is.

switch (PT) {
case dxbc::RootParameterType::Constants32Bit: {
DXContainerYAML::RootConstantsYaml &Constants =
S.Parameters.getOrInsertConstants(L);
IO.mapRequired("Constants", Constants);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
DXContainerYAML::RootDescriptorYaml &Descriptor =
S.Parameters.getOrInsertDescriptor(L);
IO.mapRequired("Descriptor", Descriptor);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
DXContainerYAML::DescriptorTableYaml &Table =
S.Parameters.getOrInsertTable(L);
IO.mapRequired("Table", Table);
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Target/DirectX/DXILRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,20 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
OS << "- Parameter Type: " << Type << "\n"
<< " Shader Visibility: " << Header.ShaderVisibility << "\n";

switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
assert(dxbc::isValidParameterType(Type) && "Invalid Parameter Type");
dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
switch (PT) {
case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
RS.ParametersContainer.getConstant(Loc);
OS << " Register Space: " << Constants.RegisterSpace << "\n"
<< " Shader Register: " << Constants.ShaderRegister << "\n"
<< " Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RS.ParametersContainer.getRootDescriptor(Loc);
OS << " Register Space: " << Descriptor.RegisterSpace << "\n"
Expand All @@ -195,7 +197,7 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
OS << " Flags: " << Descriptor.Flags << "\n";
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RS.ParametersContainer.getDescriptorTable(Loc);
OS << " NumRanges: " << Table.Ranges.size() << "\n";
Expand Down
Loading