Skip to content
27 changes: 15 additions & 12 deletions llvm/include/llvm/MC/DXContainerRootSignature.h
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this pr is meant to remove any dependency from MC onto Object/DXContainer.h?

Shouldn't that mean that some include is removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It not from Object/DXContainer.h, it is from some values defined inside BinaryFormat/DXContainer.h mainly the RTS0 namespace

Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@ namespace llvm {
class raw_ostream;
namespace mcdxbc {

struct RootParameterHeader {
dxbc::RootParameterType ParameterType;
dxbc::ShaderVisibility ShaderVisibility;
uint32_t ParameterOffset;
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer used


struct RootParameterInfo {
dxbc::RTS0::v1::RootParameterHeader Header;
RootParameterHeader Header;
size_t Location;

RootParameterInfo() = default;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to the change, but is this actually used? I think the default constructor can be removed.


RootParameterInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location)
RootParameterInfo(RootParameterHeader Header, size_t Location)
: Header(Header), Location(Location) {}
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point I think it might simplify things to just have the Info object and forgo the local header type completely, as in

struct RootParameterInfo {
  dxbc::RootParameterType Type;
  dxbc::ShaderVisibility Visibility;
  uint32_t Offset;
  size_t Location;

  RootParameterInfo(dxbc::RootParameterType Type,
                    dxbc::ShaderVisibility Visibility, uint32_t Offset,
                    size_t Location)
      : Type(Type), Visibility(Visibility), Offset(Offset), Location(Location) {
  }
};

then the addParameter functions below can just take Type/Visibility/Offset parameters directly instead of a Header object.


Expand All @@ -46,39 +52,36 @@ struct RootParametersContainer {
SmallVector<dxbc::RTS0::v2::RootDescriptor> Descriptors;
SmallVector<DescriptorTable> Tables;

void addInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location) {
void addInfo(RootParameterHeader Header, size_t Location) {
ParametersInfo.push_back(RootParameterInfo(Header, Location));
}

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
void addParameter(RootParameterHeader Header,
dxbc::RTS0::v1::RootConstants Constant) {
addInfo(Header, Constants.size());
Constants.push_back(Constant);
}

void addInvalidParameter(dxbc::RTS0::v1::RootParameterHeader Header) {
addInfo(Header, -1);
}
void addInvalidParameter(RootParameterHeader Header) { addInfo(Header, -1); }

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
void addParameter(RootParameterHeader Header,
dxbc::RTS0::v2::RootDescriptor Descriptor) {
addInfo(Header, Descriptors.size());
Descriptors.push_back(Descriptor);
}

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
DescriptorTable Table) {
void addParameter(RootParameterHeader Header, DescriptorTable Table) {
addInfo(Header, Tables.size());
Tables.push_back(Table);
}

std::pair<uint32_t, uint32_t>
std::pair<dxbc::RootParameterType, uint32_t>
getTypeAndLocForParameter(uint32_t Location) const {
const RootParameterInfo &Info = ParametersInfo[Location];
return {Info.Header.ParameterType, Info.Location};
}

const dxbc::RTS0::v1::RootParameterHeader &getHeader(size_t Location) const {
const RootParameterHeader &getHeader(size_t Location) const {
const RootParameterInfo &Info = ParametersInfo[Location];
return Info.Header;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could probably replace both of these with a getInfo and access the fields of the struct directly.

Expand Down
74 changes: 41 additions & 33 deletions llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
return NodeText->getString();
}

static Expected<dxbc::ShaderVisibility>
extractShaderVisibility(MDNode *Node, unsigned int OpId) {
if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
if (!dxbc::isValidShaderVisibility(*Val))
return make_error<RootSignatureValidationError<uint32_t>>(
"ShaderVisibility", *Val);
return static_cast<dxbc::ShaderVisibility>(*Val);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think constructor syntax is more usual for creating an enum from an underlying value than static cast:

Suggested change
return static_cast<dxbc::ShaderVisibility>(*Val);
return dxbc::ShaderVisibility(*Val);

}
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
}

namespace {

// We use the OverloadVisit with std::visit to ensure the compiler catches if a
Expand Down Expand Up @@ -224,15 +235,16 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
if (RootConstantNode->getNumOperands() != 5)
return make_error<InvalidRSMetadataFormat>("RootConstants Element");

dxbc::RTS0::v1::RootParameterHeader Header;
mcdxbc::RootParameterHeader Header;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer used - please check your compiler warnings

// The parameter offset doesn't matter here - we recalculate it during
// serialization Header.ParameterOffset = 0;
Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit);
Header.ParameterType = dxbc::RootParameterType::Constants32Bit;

if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
Expected<dxbc::ShaderVisibility> VisibilityOrErr =
extractShaderVisibility(RootConstantNode, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: I don't think the "OrErr" in this name is adding anything - that's usually used when we immediately store to something without the suffix (So we have Foo = *FooOrErr). Just Visibility is fine for all of these.

if (auto E = VisibilityOrErr.takeError())
return Error(std::move(E));
Header.ShaderVisibility = *VisibilityOrErr;

dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Expand Down Expand Up @@ -266,26 +278,27 @@ Error MetadataParser::parseRootDescriptors(
if (RootDescriptorNode->getNumOperands() != 5)
return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");

dxbc::RTS0::v1::RootParameterHeader Header;
mcdxbc::RootParameterHeader Header;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV);
Header.ParameterType = dxbc::RootParameterType::SRV;
break;
case RootSignatureElementKind::UAV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV);
Header.ParameterType = dxbc::RootParameterType::UAV;
break;
case RootSignatureElementKind::CBV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV);
Header.ParameterType = dxbc::RootParameterType::CBV;
break;
default:
llvm_unreachable("invalid Root Descriptor kind");
break;
}

if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
Expected<dxbc::ShaderVisibility> VisibilityOrErr =
extractShaderVisibility(RootDescriptorNode, 1);
if (auto E = VisibilityOrErr.takeError())
return Error(std::move(E));
Header.ShaderVisibility = *VisibilityOrErr;

dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Expand Down Expand Up @@ -375,15 +388,16 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
if (NumOperands < 2)
return make_error<InvalidRSMetadataFormat>("Descriptor Table");

dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
mcdxbc::RootParameterHeader Header;

Expected<dxbc::ShaderVisibility> VisibilityOrErr =
extractShaderVisibility(DescriptorTableNode, 1);
if (auto E = VisibilityOrErr.takeError())
return Error(std::move(E));
Header.ShaderVisibility = *VisibilityOrErr;

mcdxbc::DescriptorTable Table;
Header.ParameterType =
to_underlying(dxbc::RootParameterType::DescriptorTable);
Header.ParameterType = dxbc::RootParameterType::DescriptorTable;

for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
Expand Down Expand Up @@ -531,20 +545,14 @@ Error MetadataParser::validateRootSignature(
}

for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"ShaderVisibility", Info.Header.ShaderVisibility));

assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");

switch (Info.Header.ParameterType) {
case dxbc::RootParameterType::Constants32Bit:
break;

case to_underlying(dxbc::RootParameterType::CBV):
case to_underlying(dxbc::RootParameterType::UAV):
case to_underlying(dxbc::RootParameterType::SRV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
Expand All @@ -569,7 +577,7 @@ Error MetadataParser::validateRootSignature(
}
break;
}
case to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
Expand Down
20 changes: 10 additions & 10 deletions llvm/lib/MC/DXContainerRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ size_t RootSignatureDesc::getSize() const {

for (const RootParameterInfo &I : ParametersContainer) {
switch (I.Header.ParameterType) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
case dxbc::RootParameterType::Constants32Bit:
Size += sizeof(dxbc::RTS0::v1::RootConstants);
break;
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV:
if (Version == 1)
Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
else
Size += sizeof(dxbc::RTS0::v2::RootDescriptor);

break;
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
case dxbc::RootParameterType::DescriptorTable:
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(I.Location);

Expand Down Expand Up @@ -98,7 +98,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
ParametersContainer.getConstant(Loc);
support::endian::write(BOS, Constants.ShaderRegister,
Expand All @@ -109,9 +109,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
ParametersContainer.getRootDescriptor(Loc);

Expand All @@ -123,7 +123,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(Loc);
support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
Expand Down
19 changes: 11 additions & 8 deletions llvm/lib/ObjectYAML/DXContainerEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,14 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {

for (DXContainerYAML::RootParameterLocationYaml &L :
P.RootSignature->Parameters.Locations) {
dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
L.Header.Offset};

switch (L.Header.Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
mcdxbc::RootParameterHeader Header{
static_cast<dxbc::RootParameterType>(L.Header.Type),
static_cast<dxbc::ShaderVisibility>(L.Header.Visibility),
L.Header.Offset};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will become moot after the next step (updating the YAML representation), but it would be good to make our assumptions clear here. A couple of asserts should do:

        assert(dxbc::isValidParameterType(L.Header.Type && "invalid DXContainer YAML");
        assert(dxbc::isValidShaderVisibility(L.Header.Visibility && "invalid DXContainer YAML");
        dxbc::RootParameterType Type(L.Header.Type);
        dxbc::ShaderVisibility Visibility(L.Header.Visibility);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added those assert in some point. The problem I faced was: this asserts would make me remove 2 tests, those would be removed eventually since, the yaml representation makes them impossible, but in order to keep the NFC in this PR I opted to keep the invalid parameter handling. But I do agree, it is better to make those assumptions explicit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair, but since we're casting without the check here those tests are presumably hitting UB now anyway - I think it'll be better to assert and just preemptively remove those tests rather than jump through hoops here.


switch (Header.ParameterType) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the default label in this switch becomes unreachable when we're switching over the enum itself, so it will need to be removed.

case dxbc::RootParameterType::Constants32Bit: {
const DXContainerYAML::RootConstantsYaml &ConstantYaml =
P.RootSignature->Parameters.getOrInsertConstants(L);
dxbc::RTS0::v1::RootConstants Constants;
Expand All @@ -289,9 +292,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Constants);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
P.RootSignature->Parameters.getOrInsertDescriptor(L);

Expand All @@ -303,7 +306,7 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
RS.ParametersContainer.addParameter(Header, Descriptor);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const DXContainerYAML::DescriptorTableYaml &TableYaml =
P.RootSignature->Parameters.getOrInsertTable(L);
mcdxbc::DescriptorTable Table;
Expand Down
20 changes: 12 additions & 8 deletions llvm/lib/Target/DirectX/DXILRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,24 +173,28 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
const auto &[Type, Loc] =
RS.ParametersContainer.getTypeAndLocForParameter(I);
const dxbc::RTS0::v1::RootParameterHeader Header =
const mcdxbc::RootParameterHeader Header =
RS.ParametersContainer.getHeader(I);

OS << "- Parameter Type: " << Type << "\n"
<< " Shader Visibility: " << Header.ShaderVisibility << "\n";
OS << "- Parameter Type: "
<< enumToStringRef(Type, dxbc::getRootParameterTypes()) << "\n"
<< " Shader Visibility: "
<< enumToStringRef(Header.ShaderVisibility,
dxbc::getShaderVisibility())
<< "\n";

switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
RS.ParametersContainer.getConstant(Loc);
OS << " Register Space: " << Constants.RegisterSpace << "\n"
<< " Shader Register: " << Constants.ShaderRegister << "\n"
<< " Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RS.ParametersContainer.getRootDescriptor(Loc);
OS << " Register Space: " << Descriptor.RegisterSpace << "\n"
Expand All @@ -199,7 +203,7 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
OS << " Flags: " << Descriptor.Flags << "\n";
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RS.ParametersContainer.getDescriptorTable(Loc);
OS << " NumRanges: " << Table.Ranges.size() << "\n";
Expand Down
Loading
Loading