Skip to content

[DirectX][NFC] Refactoring DirectX backend to not use llvm::to_underlying in switch cases. #151032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
Expand Down Expand Up @@ -559,11 +560,17 @@ bool MetadataParser::validateRootSignature(
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");

switch (Info.Header.ParameterType) {
dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);

case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
switch (PT) {
case dxbc::RootParameterType::Constants32Bit:
// ToDo: Add proper validation.
continue;

case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
Expand All @@ -580,7 +587,7 @@ bool MetadataParser::validateRootSignature(
}
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
Expand Down
7 changes: 3 additions & 4 deletions llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {

bool verifyRangeType(uint32_t Type) {
switch (Type) {
case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
#define DESCRIPTOR_RANGE(Num, Val) \
case llvm::to_underlying(dxbc::DescriptorRangeType::Val):
#include "llvm/BinaryFormat/DXContainerConstants.def"
Comment on lines -56 to +58
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not updating to not use to_underlying

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This specific case is necessary, to keep using llvm::to_underlying, since that is checking if an uint_32t is valid value for RootParametersType. However, I updated to use the tablegen definition, that way we will always have it covering all possible values.

Or we can change how the check is being done, if folks prefer.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to me like we get the same end result by doing the int->enum cast once before the switch and having the switch statement operate on the actual enumerations. I have a strong preference for not using to_underlying since it impairs the frontend's ability to generate warnings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this point, we are not sure Type is a valid DescriptorRangeType so the casting here int->enum would cause undefined behaviour.

return true;
};

Expand Down
35 changes: 23 additions & 12 deletions llvm/lib/MC/DXContainerRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Support/EndianStream.h"

using namespace llvm;
Expand Down Expand Up @@ -35,20 +36,26 @@ size_t RootSignatureDesc::getSize() const {
StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);

for (const RootParameterInfo &I : ParametersContainer) {
switch (I.Header.ParameterType) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
if (!dxbc::isValidParameterType(I.Header.ParameterType))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is ParameterType a uint32_t instead of being the enum type?

It seems like a lot of complexity here would go away if we just stored these as the appropriate type, like we do for enums stored in the ProgramSignatureElement structure.

I suspect it is actually impossible to get to this point in the code with a parameter type that isn't legal (if it isn't impossible it probably should be) or we should be surfacing errors.

continue;

dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(I.Header.ParameterType);

switch (PT) {
case dxbc::RootParameterType::Constants32Bit:
Size += sizeof(dxbc::RTS0::v1::RootConstants);
break;
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV:
if (Version == 1)
Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
else
Size += sizeof(dxbc::RTS0::v2::RootDescriptor);

break;
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
case dxbc::RootParameterType::DescriptorTable:
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(I.Location);

Expand Down Expand Up @@ -97,8 +104,12 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
for (size_t I = 0; I < NumParameters; ++I) {
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
if (!dxbc::isValidParameterType(Type))
continue;
dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);

switch (PT) {
case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
ParametersContainer.getConstant(Loc);
support::endian::write(BOS, Constants.ShaderRegister,
Expand All @@ -109,9 +120,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
ParametersContainer.getRootDescriptor(Loc);

Expand All @@ -123,7 +134,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(Loc);
support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
Expand Down
28 changes: 17 additions & 11 deletions llvm/lib/ObjectYAML/DXContainerEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,19 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
L.Header.Offset};

switch (L.Header.Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
if (!dxbc::isValidParameterType(L.Header.Type)) {
// Handling invalid parameter type edge case. We intentionally let
// obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
// for that to be used as a testing tool more effectively.
RS.ParametersContainer.addInvalidParameter(Header);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I question the utility of this. It doesn't actually enable us to emit a binary that contains potentially well-formed parameters that our implementation doesn't understand.

That said, if we're using it for test I guess it is okay, it just has the drawback of forcing the MC code to be resilient to invalid parameter types which wouldn't be necessary if not for this.

continue;
}

dxbc::RootParameterType ParameterType =
static_cast<dxbc::RootParameterType>(L.Header.Type);

switch (ParameterType) {
case dxbc::RootParameterType::Constants32Bit: {
const DXContainerYAML::RootConstantsYaml &ConstantYaml =
P.RootSignature->Parameters.getOrInsertConstants(L);
dxbc::RTS0::v1::RootConstants Constants;
Expand All @@ -289,9 +300,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Constants);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
P.RootSignature->Parameters.getOrInsertDescriptor(L);

Expand All @@ -303,7 +314,7 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Descriptor);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const DXContainerYAML::DescriptorTableYaml &TableYaml =
P.RootSignature->Parameters.getOrInsertTable(L);
mcdxbc::DescriptorTable Table;
Expand All @@ -323,11 +334,6 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Table);
break;
}
default:
// Handling invalid parameter type edge case. We intentionally let
// obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
// for that to be used as a testing tool more effectively.
RS.ParametersContainer.addInvalidParameter(Header);
}
}

Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/ObjectYAML/DXContainerYAML.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,22 +424,28 @@ void MappingContextTraits<DXContainerYAML::RootParameterLocationYaml,
IO.mapRequired("ParameterType", L.Header.Type);
IO.mapRequired("ShaderVisibility", L.Header.Visibility);

switch (L.Header.Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
if (!dxbc::isValidParameterType(L.Header.Type))
return;
dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(L.Header.Type);

// We allow ParameterType to be invalid here.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment doesn't match the implementation. It returns above if the parameter type is invalid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the comments to make it clearer what the current behaviour is.

switch (PT) {
case dxbc::RootParameterType::Constants32Bit: {
DXContainerYAML::RootConstantsYaml &Constants =
S.Parameters.getOrInsertConstants(L);
IO.mapRequired("Constants", Constants);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
DXContainerYAML::RootDescriptorYaml &Descriptor =
S.Parameters.getOrInsertDescriptor(L);
IO.mapRequired("Descriptor", Descriptor);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
DXContainerYAML::DescriptorTableYaml &Table =
S.Parameters.getOrInsertTable(L);
IO.mapRequired("Table", Table);
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Target/DirectX/DXILRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,20 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
OS << "- Parameter Type: " << Type << "\n"
<< " Shader Visibility: " << Header.ShaderVisibility << "\n";

switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
assert(dxbc::isValidParameterType(Type) && "Invalid Parameter Type");
dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
switch (PT) {
case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
RS.ParametersContainer.getConstant(Loc);
OS << " Register Space: " << Constants.RegisterSpace << "\n"
<< " Shader Register: " << Constants.ShaderRegister << "\n"
<< " Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RS.ParametersContainer.getRootDescriptor(Loc);
OS << " Register Space: " << Descriptor.RegisterSpace << "\n"
Expand All @@ -195,7 +197,7 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
OS << " Flags: " << Descriptor.Flags << "\n";
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RS.ParametersContainer.getDescriptorTable(Loc);
OS << " NumRanges: " << Table.Ranges.size() << "\n";
Expand Down
Loading