Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 36 additions & 32 deletions llvm/include/llvm/MC/DXContainerRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,26 @@ namespace llvm {
class raw_ostream;
namespace mcdxbc {

struct RootConstants {
uint32_t ShaderRegister;
uint32_t RegisterSpace;
uint32_t Num32BitValues;
};

struct RootDescriptor {
uint32_t ShaderRegister;
uint32_t RegisterSpace;
uint32_t Flags;
};

struct RootParameterInfo {
dxbc::RTS0::v1::RootParameterHeader Header;
dxbc::RootParameterType Type;
dxbc::ShaderVisibility Visibility;
size_t Location;

RootParameterInfo() = default;

RootParameterInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location)
: Header(Header), Location(Location) {}
RootParameterInfo(dxbc::RootParameterType Type,
dxbc::ShaderVisibility Visibility, size_t Location)
: Type(Type), Visibility(Visibility), Location(Location) {}
};

struct DescriptorTable {
Expand All @@ -42,52 +54,44 @@ struct DescriptorTable {
struct RootParametersContainer {
SmallVector<RootParameterInfo> ParametersInfo;

SmallVector<dxbc::RTS0::v1::RootConstants> Constants;
SmallVector<dxbc::RTS0::v2::RootDescriptor> Descriptors;
SmallVector<RootConstants> Constants;
SmallVector<RootDescriptor> Descriptors;
SmallVector<DescriptorTable> Tables;

void addInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location) {
ParametersInfo.push_back(RootParameterInfo(Header, Location));
void addInfo(dxbc::RootParameterType Type, dxbc::ShaderVisibility Visibility,
size_t Location) {
ParametersInfo.emplace_back(Type, Visibility, Location);
}

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
dxbc::RTS0::v1::RootConstants Constant) {
addInfo(Header, Constants.size());
void addParameter(dxbc::RootParameterType Type,
dxbc::ShaderVisibility Visibility, RootConstants Constant) {
addInfo(Type, Visibility, Constants.size());
Constants.push_back(Constant);
}

void addInvalidParameter(dxbc::RTS0::v1::RootParameterHeader Header) {
addInfo(Header, -1);
}

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
dxbc::RTS0::v2::RootDescriptor Descriptor) {
addInfo(Header, Descriptors.size());
void addParameter(dxbc::RootParameterType Type,
dxbc::ShaderVisibility Visibility,
RootDescriptor Descriptor) {
addInfo(Type, Visibility, Descriptors.size());
Descriptors.push_back(Descriptor);
}

void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
DescriptorTable Table) {
addInfo(Header, Tables.size());
void addParameter(dxbc::RootParameterType Type,
dxbc::ShaderVisibility Visibility, DescriptorTable Table) {
addInfo(Type, Visibility, Tables.size());
Tables.push_back(Table);
}

std::pair<uint32_t, uint32_t>
getTypeAndLocForParameter(uint32_t Location) const {
const RootParameterInfo &Info = ParametersInfo[Location];
return {Info.Header.ParameterType, Info.Location};
}

const dxbc::RTS0::v1::RootParameterHeader &getHeader(size_t Location) const {
const RootParameterInfo &getInfo(uint32_t Location) const {
const RootParameterInfo &Info = ParametersInfo[Location];
return Info.Header;
return Info;
}

const dxbc::RTS0::v1::RootConstants &getConstant(size_t Index) const {
const RootConstants &getConstant(size_t Index) const {
return Constants[Index];
}

const dxbc::RTS0::v2::RootDescriptor &getRootDescriptor(size_t Index) const {
const RootDescriptor &getRootDescriptor(size_t Index) const {
return Descriptors[Index];
}

Expand Down
87 changes: 43 additions & 44 deletions llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
return NodeText->getString();
}

static Expected<dxbc::ShaderVisibility>
extractShaderVisibility(MDNode *Node, unsigned int OpId) {
if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
if (!dxbc::isValidShaderVisibility(*Val))
return make_error<RootSignatureValidationError<uint32_t>>(
"ShaderVisibility", *Val);
return dxbc::ShaderVisibility(*Val);
}
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
}

namespace {

// We use the OverloadVisit with std::visit to ensure the compiler catches if a
Expand Down Expand Up @@ -224,17 +235,12 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
if (RootConstantNode->getNumOperands() != 5)
return make_error<InvalidRSMetadataFormat>("RootConstants Element");

dxbc::RTS0::v1::RootParameterHeader Header;
// The parameter offset doesn't matter here - we recalculate it during
// serialization Header.ParameterOffset = 0;
Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit);

if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
Expected<dxbc::ShaderVisibility> Visibility =
extractShaderVisibility(RootConstantNode, 1);
if (auto E = Visibility.takeError())
return Error(std::move(E));

dxbc::RTS0::v1::RootConstants Constants;
mcdxbc::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
Expand All @@ -250,7 +256,8 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
else
return make_error<InvalidRSMetadataValue>("Num32BitValues");

RSD.ParametersContainer.addParameter(Header, Constants);
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
*Visibility, Constants);

return Error::success();
}
Expand All @@ -266,28 +273,28 @@ Error MetadataParser::parseRootDescriptors(
if (RootDescriptorNode->getNumOperands() != 5)
return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");

dxbc::RTS0::v1::RootParameterHeader Header;
dxbc::RootParameterType Type;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV);
Type = dxbc::RootParameterType::SRV;
break;
case RootSignatureElementKind::UAV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV);
Type = dxbc::RootParameterType::UAV;
break;
case RootSignatureElementKind::CBV:
Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV);
Type = dxbc::RootParameterType::CBV;
break;
default:
llvm_unreachable("invalid Root Descriptor kind");
break;
}

if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
Expected<dxbc::ShaderVisibility> Visibility =
extractShaderVisibility(RootDescriptorNode, 1);
if (auto E = Visibility.takeError())
return Error(std::move(E));

dxbc::RTS0::v2::RootDescriptor Descriptor;
mcdxbc::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
Expand All @@ -299,7 +306,7 @@ Error MetadataParser::parseRootDescriptors(
return make_error<InvalidRSMetadataValue>("RegisterSpace");

if (RSD.Version == 1) {
RSD.ParametersContainer.addParameter(Header, Descriptor);
RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
return Error::success();
}
assert(RSD.Version > 1);
Expand All @@ -309,7 +316,7 @@ Error MetadataParser::parseRootDescriptors(
else
return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");

RSD.ParametersContainer.addParameter(Header, Descriptor);
RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
return Error::success();
}

Expand Down Expand Up @@ -375,15 +382,12 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
if (NumOperands < 2)
return make_error<InvalidRSMetadataFormat>("Descriptor Table");

dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
Expected<dxbc::ShaderVisibility> Visibility =
extractShaderVisibility(DescriptorTableNode, 1);
if (auto E = Visibility.takeError())
return Error(std::move(E));

mcdxbc::DescriptorTable Table;
Header.ParameterType =
to_underlying(dxbc::RootParameterType::DescriptorTable);

for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
Expand All @@ -395,7 +399,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
return Err;
}

RSD.ParametersContainer.addParameter(Header, Table);
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable,
*Visibility, Table);
return Error::success();
}

Expand Down Expand Up @@ -531,21 +536,15 @@ Error MetadataParser::validateRootSignature(
}

for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
make_error<RootSignatureValidationError<uint32_t>>(
"ShaderVisibility", Info.Header.ShaderVisibility));

assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");

switch (Info.Header.ParameterType) {
switch (Info.Type) {
case dxbc::RootParameterType::Constants32Bit:
break;

case to_underlying(dxbc::RootParameterType::CBV):
case to_underlying(dxbc::RootParameterType::UAV):
case to_underlying(dxbc::RootParameterType::SRV): {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const mcdxbc::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
DeferredErrs =
Expand All @@ -569,7 +568,7 @@ Error MetadataParser::validateRootSignature(
}
break;
}
case to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
Expand Down
44 changes: 21 additions & 23 deletions llvm/lib/MC/DXContainerRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@ size_t RootSignatureDesc::getSize() const {
StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);

for (const RootParameterInfo &I : ParametersContainer) {
switch (I.Header.ParameterType) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
switch (I.Type) {
case dxbc::RootParameterType::Constants32Bit:
Size += sizeof(dxbc::RTS0::v1::RootConstants);
break;
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV:
if (Version == 1)
Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
else
Size += sizeof(dxbc::RTS0::v2::RootDescriptor);

break;
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
case dxbc::RootParameterType::DescriptorTable:
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(I.Location);

Expand Down Expand Up @@ -84,23 +84,21 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
support::endian::write(BOS, Flags, llvm::endianness::little);

SmallVector<uint32_t> ParamsOffsets;
for (const RootParameterInfo &P : ParametersContainer) {
support::endian::write(BOS, P.Header.ParameterType,
llvm::endianness::little);
support::endian::write(BOS, P.Header.ShaderVisibility,
llvm::endianness::little);
for (const RootParameterInfo &I : ParametersContainer) {
support::endian::write(BOS, I.Type, llvm::endianness::little);
support::endian::write(BOS, I.Visibility, llvm::endianness::little);

ParamsOffsets.push_back(writePlaceholder(BOS));
}

assert(NumParameters == ParamsOffsets.size());
for (size_t I = 0; I < NumParameters; ++I) {
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
const dxbc::RTS0::v1::RootConstants &Constants =
ParametersContainer.getConstant(Loc);
const auto Info = ParametersContainer.getInfo(I);
switch (Info.Type) {
case dxbc::RootParameterType::Constants32Bit: {
const mcdxbc::RootConstants &Constants =
ParametersContainer.getConstant(Info.Location);
support::endian::write(BOS, Constants.ShaderRegister,
llvm::endianness::little);
support::endian::write(BOS, Constants.RegisterSpace,
Expand All @@ -109,11 +107,11 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::SRV):
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
ParametersContainer.getRootDescriptor(Loc);
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV: {
const mcdxbc::RootDescriptor &Descriptor =
ParametersContainer.getRootDescriptor(Info.Location);

support::endian::write(BOS, Descriptor.ShaderRegister,
llvm::endianness::little);
Expand All @@ -123,9 +121,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
case dxbc::RootParameterType::DescriptorTable: {
const DescriptorTable &Table =
ParametersContainer.getDescriptorTable(Loc);
ParametersContainer.getDescriptorTable(Info.Location);
support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
llvm::endianness::little);
rewriteOffsetToCurrentByte(BOS, writePlaceholder(BOS));
Expand Down
Loading