Skip to content
49 changes: 25 additions & 24 deletions llvm/include/llvm/MC/DXContainerRootSignature.h
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this pr is meant to remove any dependency from MC onto Object/DXContainer.h?

Shouldn't that mean that some include is removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It not from Object/DXContainer.h, it is from some values defined inside BinaryFormat/DXContainer.h mainly the RTS0 namespace

Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,22 @@ namespace llvm {
class raw_ostream;
namespace mcdxbc {

struct RootParameterHeader {
dxbc::RootParameterType ParameterType;
dxbc::ShaderVisibility ShaderVisibility;
uint32_t ParameterOffset;
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer used


struct RootParameterInfo {
dxbc::RTS0::v1::RootParameterHeader Header;
dxbc::RootParameterType Type;
dxbc::ShaderVisibility Visibility;
size_t Location;

RootParameterInfo() = default;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to the change, but is this actually used? I think the default constructor can be removed.


RootParameterInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location)
: Header(Header), Location(Location) {}
RootParameterInfo(dxbc::RootParameterType Type,
dxbc::ShaderVisibility Visibility, size_t Location)
: Type(Type), Visibility(Visibility), Location(Location) {}
};

struct DescriptorTable {
Expand All @@ -46,41 +54,34 @@ struct RootParametersContainer {
SmallVector<dxbc::RTS0::v2::RootDescriptor> Descriptors;
SmallVector<DescriptorTable> Tables;

void addInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location) {
ParametersInfo.push_back(RootParameterInfo(Header, Location));
void addInfo(dxbc::RootParameterType Type, dxbc::ShaderVisibility Visibility,
size_t Location) {
ParametersInfo.push_back(RootParameterInfo(Type, Visibility, Location));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would save a copy to use emplace_back here:

Suggested change
ParametersInfo.push_back(RootParameterInfo(Type, Visibility, Location));
ParametersInfo.emplace_back(Type, Visibility, Location);

}

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
void addParameter(dxbc::RootParameterType Type,
dxbc::ShaderVisibility Visibility,
dxbc::RTS0::v1::RootConstants Constant) {
addInfo(Header, Constants.size());
addInfo(Type, Visibility, Constants.size());
Constants.push_back(Constant);
}

void addInvalidParameter(dxbc::RTS0::v1::RootParameterHeader Header) {
addInfo(Header, -1);
}

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
void addParameter(dxbc::RootParameterType Type,
dxbc::ShaderVisibility Visibility,
dxbc::RTS0::v2::RootDescriptor Descriptor) {
addInfo(Header, Descriptors.size());
addInfo(Type, Visibility, Descriptors.size());
Descriptors.push_back(Descriptor);
}

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
DescriptorTable Table) {
addInfo(Header, Tables.size());
void addParameter(dxbc::RootParameterType Type,
dxbc::ShaderVisibility Visibility, DescriptorTable Table) {
addInfo(Type, Visibility, Tables.size());
Tables.push_back(Table);
}

std::pair<uint32_t, uint32_t>
getTypeAndLocForParameter(uint32_t Location) const {
const RootParameterInfo &Info = ParametersInfo[Location];
return {Info.Header.ParameterType, Info.Location};
}

const dxbc::RTS0::v1::RootParameterHeader &getHeader(size_t Location) const {
const RootParameterInfo &getInfo(uint32_t Location) const {
const RootParameterInfo &Info = ParametersInfo[Location];
return Info.Header;
return Info;
}

const dxbc::RTS0::v1::RootConstants &getConstant(size_t Index) const {
Expand Down
83 changes: 43 additions & 40 deletions llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
return NodeText->getString();
}

static Expected<dxbc::ShaderVisibility>
extractShaderVisibility(MDNode *Node, unsigned int OpId) {
if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
if (!dxbc::isValidShaderVisibility(*Val))
return make_error<RootSignatureValidationError<uint32_t>>(
"ShaderVisibility", *Val);
return dxbc::ShaderVisibility(*Val);
}
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
}

namespace {

// We use the OverloadVisit with std::visit to ensure the compiler catches if a
Expand Down Expand Up @@ -224,15 +235,12 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
if (RootConstantNode->getNumOperands() != 5)
return make_error<InvalidRSMetadataFormat>("RootConstants Element");

dxbc::RTS0::v1::RootParameterHeader Header;
// The parameter offset doesn't matter here - we recalculate it during
// serialization Header.ParameterOffset = 0;
Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit);
mcdxbc::RootParameterHeader Header;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer used - please check your compiler warnings


if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
Expected<dxbc::ShaderVisibility> VisibilityOrErr =
extractShaderVisibility(RootConstantNode, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: I don't think the "OrErr" in this name is adding anything - that's usually used when we immediately store to something without the suffix (So we have Foo = *FooOrErr). Just Visibility is fine for all of these.

if (auto E = VisibilityOrErr.takeError())
return Error(std::move(E));

dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Expand All @@ -250,7 +258,8 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
else
return make_error<InvalidRSMetadataValue>("Num32BitValues");

RSD.ParametersContainer.addParameter(Header, Constants);
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
*VisibilityOrErr, Constants);

return Error::success();
}
Expand All @@ -266,26 +275,26 @@ Error MetadataParser::parseRootDescriptors(
if (RootDescriptorNode->getNumOperands() != 5)
return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");

dxbc::RTS0::v1::RootParameterHeader Header;
dxbc::RootParameterType Type;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV);
Type = dxbc::RootParameterType::SRV;
break;
case RootSignatureElementKind::UAV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV);
Type = dxbc::RootParameterType::UAV;
break;
case RootSignatureElementKind::CBV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV);
Type = dxbc::RootParameterType::CBV;
break;
default:
llvm_unreachable("invalid Root Descriptor kind");
break;
}

if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
Expected<dxbc::ShaderVisibility> VisibilityOrErr =
extractShaderVisibility(RootDescriptorNode, 1);
if (auto E = VisibilityOrErr.takeError())
return Error(std::move(E));

dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Expand All @@ -299,7 +308,7 @@ Error MetadataParser::parseRootDescriptors(
return make_error<InvalidRSMetadataValue>("RegisterSpace");

if (RSD.Version == 1) {
RSD.ParametersContainer.addParameter(Header, Descriptor);
RSD.ParametersContainer.addParameter(Type, *VisibilityOrErr, Descriptor);
return Error::success();
}
assert(RSD.Version > 1);
Expand All @@ -309,7 +318,7 @@ Error MetadataParser::parseRootDescriptors(
else
return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");

RSD.ParametersContainer.addParameter(Header, Descriptor);
RSD.ParametersContainer.addParameter(Type, *VisibilityOrErr, Descriptor);
return Error::success();
}

Expand Down Expand Up @@ -375,15 +384,14 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
if (NumOperands < 2)
return make_error<InvalidRSMetadataFormat>("Descriptor Table");

dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
mcdxbc::RootParameterHeader Header;

Expected<dxbc::ShaderVisibility> VisibilityOrErr =
extractShaderVisibility(DescriptorTableNode, 1);
if (auto E = VisibilityOrErr.takeError())
return Error(std::move(E));

mcdxbc::DescriptorTable Table;
Header.ParameterType =
to_underlying(dxbc::RootParameterType::DescriptorTable);

for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
Expand All @@ -395,7 +403,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
return Err;
}

RSD.ParametersContainer.addParameter(Header, Table);
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable,
*VisibilityOrErr, Table);
return Error::success();
}

Expand Down Expand Up @@ -531,20 +540,14 @@ Error MetadataParser::validateRootSignature(
}

for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"ShaderVisibility", Info.Header.ShaderVisibility));

assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");

switch (Info.Header.ParameterType) {
switch (Info.Type) {
case dxbc::RootParameterType::Constants32Bit:
break;

case to_underlying(dxbc::RootParameterType::CBV):
case to_underlying(dxbc::RootParameterType::UAV):
case to_underlying(dxbc::RootParameterType::SRV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
Expand All @@ -569,7 +572,7 @@ Error MetadataParser::validateRootSignature(
}
break;
}
case to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
Expand Down
36 changes: 18 additions & 18 deletions llvm/lib/MC/DXContainerRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/EndianStream.h"
#include <cstdint>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see where this is used. Do we need it?


using namespace llvm;
using namespace llvm::mcdxbc;
Expand All @@ -35,20 +36,20 @@ size_t RootSignatureDesc::getSize() const {
StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);

for (const RootParameterInfo &I : ParametersContainer) {
switch (I.Header.ParameterType) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
switch (I.Type) {
case dxbc::RootParameterType::Constants32Bit:
Size += sizeof(dxbc::RTS0::v1::RootConstants);
break;
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV:
if (Version == 1)
Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
else
Size += sizeof(dxbc::RTS0::v2::RootDescriptor);

break;
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
case dxbc::RootParameterType::DescriptorTable:
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(I.Location);

Expand Down Expand Up @@ -84,21 +85,20 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
support::endian::write(BOS, Flags, llvm::endianness::little);

SmallVector<uint32_t> ParamsOffsets;
for (const RootParameterInfo &P : ParametersContainer) {
support::endian::write(BOS, P.Header.ParameterType,
llvm::endianness::little);
support::endian::write(BOS, P.Header.ShaderVisibility,
llvm::endianness::little);
for (const RootParameterInfo &I : ParametersContainer) {
support::endian::write(BOS, I.Type, llvm::endianness::little);
support::endian::write(BOS, I.Visibility, llvm::endianness::little);

ParamsOffsets.push_back(writePlaceholder(BOS));
}

assert(NumParameters == ParamsOffsets.size());
for (size_t I = 0; I < NumParameters; ++I) {
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
const auto Info = ParametersContainer.getInfo(I);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want a reference and not a copy here, so this needs at least const auto &. Better to just spell out the type explicitly though.

Suggested change
const auto Info = ParametersContainer.getInfo(I);
const RootParameterInfo &Info = ParametersContainer.getInfo(I);

const uint32_t &Loc = Info.Location;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const uint32_t & is a weird type, this could just be uint32_t. I'd probably just update the users of Loc to use Info.Location directly though - I think it's clearer with one less variable here.

switch (Info.Type) {
case dxbc::RootParameterType::Constants32Bit: {
const dxbc::RTS0::v1::RootConstants &Constants =
ParametersContainer.getConstant(Loc);
support::endian::write(BOS, Constants.ShaderRegister,
Expand All @@ -109,9 +109,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
ParametersContainer.getRootDescriptor(Loc);

Expand All @@ -123,7 +123,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(Loc);
support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
Expand Down
Loading