diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h index 4fbc0cf1e5954..82890bf814935 100644 --- a/llvm/include/llvm/BinaryFormat/DXContainer.h +++ b/llvm/include/llvm/BinaryFormat/DXContainer.h @@ -602,7 +602,7 @@ struct RootDescriptor : public v1::RootDescriptor { uint32_t Flags; RootDescriptor() = default; - RootDescriptor(v1::RootDescriptor &Base) + explicit RootDescriptor(v1::RootDescriptor &Base) : v1::RootDescriptor(Base), Flags(0u) {} void swapBytes() { diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h index e3c4900ba175c..3496b5fff398f 100644 --- a/llvm/include/llvm/MC/DXContainerRootSignature.h +++ b/llvm/include/llvm/MC/DXContainerRootSignature.h @@ -15,12 +15,69 @@ namespace llvm { class raw_ostream; namespace mcdxbc { -struct RootParameter { +struct RootParameterInfo { dxbc::RootParameterHeader Header; - union { - dxbc::RootConstants Constants; - dxbc::RTS0::v2::RootDescriptor Descriptor; - }; + size_t Location; + + RootParameterInfo() = default; + + RootParameterInfo(dxbc::RootParameterHeader Header, size_t Location) + : Header(Header), Location(Location) {} +}; + +struct RootParametersContainer { + SmallVector ParametersInfo; + + SmallVector Constants; + SmallVector Descriptors; + + void addInfo(dxbc::RootParameterHeader Header, size_t Location) { + ParametersInfo.push_back(RootParameterInfo(Header, Location)); + } + + void addParameter(dxbc::RootParameterHeader Header, + dxbc::RootConstants Constant) { + addInfo(Header, Constants.size()); + Constants.push_back(Constant); + } + + void addInvalidParameter(dxbc::RootParameterHeader Header) { + addInfo(Header, -1); + } + + void addParameter(dxbc::RootParameterHeader Header, + dxbc::RTS0::v2::RootDescriptor Descriptor) { + addInfo(Header, Descriptors.size()); + Descriptors.push_back(Descriptor); + } + + const std::pair + getTypeAndLocForParameter(uint32_t Location) const { + const RootParameterInfo &Info = ParametersInfo[Location]; + return {Info.Header.ParameterType, Info.Location}; + } + + const dxbc::RootParameterHeader &getHeader(size_t Location) const { + const RootParameterInfo &Info = ParametersInfo[Location]; + return Info.Header; + } + + const dxbc::RootConstants &getConstant(size_t Index) const { + return Constants[Index]; + } + + const dxbc::RTS0::v2::RootDescriptor &getRootDescriptor(size_t Index) const { + return Descriptors[Index]; + } + + size_t size() const { return ParametersInfo.size(); } + + SmallVector::const_iterator begin() const { + return ParametersInfo.begin(); + } + SmallVector::const_iterator end() const { + return ParametersInfo.end(); + } }; struct RootSignatureDesc { @@ -29,7 +86,7 @@ struct RootSignatureDesc { uint32_t RootParameterOffset = 0U; uint32_t StaticSamplersOffset = 0u; uint32_t NumStaticSamplers = 0u; - SmallVector Parameters; + mcdxbc::RootParametersContainer ParametersContainer; void write(raw_ostream &OS) const; diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp index 161711a79e467..a9394541d18da 100644 --- a/llvm/lib/MC/DXContainerRootSignature.cpp +++ b/llvm/lib/MC/DXContainerRootSignature.cpp @@ -30,10 +30,10 @@ static void rewriteOffsetToCurrentByte(raw_svector_ostream &Stream, size_t RootSignatureDesc::getSize() const { size_t Size = sizeof(dxbc::RootSignatureHeader) + - Parameters.size() * sizeof(dxbc::RootParameterHeader); + ParametersContainer.size() * sizeof(dxbc::RootParameterHeader); - for (const mcdxbc::RootParameter &P : Parameters) { - switch (P.Header.ParameterType) { + for (const RootParameterInfo &I : ParametersContainer) { + switch (I.Header.ParameterType) { case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): Size += sizeof(dxbc::RootConstants); break; @@ -56,7 +56,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const { raw_svector_ostream BOS(Storage); BOS.reserveExtraSpace(getSize()); - const uint32_t NumParameters = Parameters.size(); + const uint32_t NumParameters = ParametersContainer.size(); support::endian::write(BOS, Version, llvm::endianness::little); support::endian::write(BOS, NumParameters, llvm::endianness::little); @@ -66,7 +66,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const { support::endian::write(BOS, Flags, llvm::endianness::little); SmallVector ParamsOffsets; - for (const mcdxbc::RootParameter &P : Parameters) { + for (const RootParameterInfo &P : ParametersContainer) { support::endian::write(BOS, P.Header.ParameterType, llvm::endianness::little); support::endian::write(BOS, P.Header.ShaderVisibility, @@ -78,27 +78,33 @@ void RootSignatureDesc::write(raw_ostream &OS) const { assert(NumParameters == ParamsOffsets.size()); for (size_t I = 0; I < NumParameters; ++I) { rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]); - const mcdxbc::RootParameter &P = Parameters[I]; - - switch (P.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - support::endian::write(BOS, P.Constants.ShaderRegister, + const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I); + switch (Type) { + case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): { + const dxbc::RootConstants &Constants = + ParametersContainer.getConstant(Loc); + support::endian::write(BOS, Constants.ShaderRegister, llvm::endianness::little); - support::endian::write(BOS, P.Constants.RegisterSpace, + support::endian::write(BOS, Constants.RegisterSpace, llvm::endianness::little); - support::endian::write(BOS, P.Constants.Num32BitValues, + support::endian::write(BOS, Constants.Num32BitValues, llvm::endianness::little); break; + } case llvm::to_underlying(dxbc::RootParameterType::CBV): case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - support::endian::write(BOS, P.Descriptor.ShaderRegister, + case llvm::to_underlying(dxbc::RootParameterType::UAV): { + const dxbc::RTS0::v2::RootDescriptor &Descriptor = + ParametersContainer.getRootDescriptor(Loc); + + support::endian::write(BOS, Descriptor.ShaderRegister, llvm::endianness::little); - support::endian::write(BOS, P.Descriptor.RegisterSpace, + support::endian::write(BOS, Descriptor.RegisterSpace, llvm::endianness::little); if (Version > 1) - support::endian::write(BOS, P.Descriptor.Flags, - llvm::endianness::little); + support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little); + break; + } } } assert(Storage.size() == getSize()); diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp index 239ee9e3de9b1..c00cd3e08d59d 100644 --- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp +++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp @@ -274,27 +274,33 @@ void DXContainerWriter::writeParts(raw_ostream &OS) { RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset; for (const auto &Param : P.RootSignature->Parameters) { - mcdxbc::RootParameter NewParam; - NewParam.Header = dxbc::RootParameterHeader{ - Param.Type, Param.Visibility, Param.Offset}; + dxbc::RootParameterHeader Header{Param.Type, Param.Visibility, + Param.Offset}; switch (Param.Type) { case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - NewParam.Constants.Num32BitValues = Param.Constants.Num32BitValues; - NewParam.Constants.RegisterSpace = Param.Constants.RegisterSpace; - NewParam.Constants.ShaderRegister = Param.Constants.ShaderRegister; + dxbc::RootConstants Constants; + Constants.Num32BitValues = Param.Constants.Num32BitValues; + Constants.RegisterSpace = Param.Constants.RegisterSpace; + Constants.ShaderRegister = Param.Constants.ShaderRegister; + RS.ParametersContainer.addParameter(Header, Constants); break; case llvm::to_underlying(dxbc::RootParameterType::SRV): case llvm::to_underlying(dxbc::RootParameterType::UAV): case llvm::to_underlying(dxbc::RootParameterType::CBV): - NewParam.Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace; - NewParam.Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister; - if (P.RootSignature->Version > 1) - NewParam.Descriptor.Flags = Param.Descriptor.getEncodedFlags(); + dxbc::RTS0::v2::RootDescriptor Descriptor; + Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace; + Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister; + if (RS.Version > 1) + Descriptor.Flags = Param.Descriptor.getEncodedFlags(); + RS.ParametersContainer.addParameter(Header, Descriptor); break; + default: + // Handling invalid parameter type edge case. We intentionally let + // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order + // for that to be used as a testing tool more effectively. + RS.ParametersContainer.addInvalidParameter(Header); } - - RS.Parameters.push_back(NewParam); } RS.write(OS); diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index ef299c17baf76..43e06ee278b49 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -75,31 +75,34 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (RootConstantNode->getNumOperands() != 5) return reportError(Ctx, "Invalid format for RootConstants Element"); - mcdxbc::RootParameter NewParameter; - NewParameter.Header.ParameterType = + dxbc::RootParameterHeader Header; + // The parameter offset doesn't matter here - we recalculate it during + // serialization Header.ParameterOffset = 0; + Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); if (std::optional Val = extractMdIntValue(RootConstantNode, 1)) - NewParameter.Header.ShaderVisibility = *Val; + Header.ShaderVisibility = *Val; else return reportError(Ctx, "Invalid value for ShaderVisibility"); + dxbc::RootConstants Constants; if (std::optional Val = extractMdIntValue(RootConstantNode, 2)) - NewParameter.Constants.ShaderRegister = *Val; + Constants.ShaderRegister = *Val; else return reportError(Ctx, "Invalid value for ShaderRegister"); if (std::optional Val = extractMdIntValue(RootConstantNode, 3)) - NewParameter.Constants.RegisterSpace = *Val; + Constants.RegisterSpace = *Val; else return reportError(Ctx, "Invalid value for RegisterSpace"); if (std::optional Val = extractMdIntValue(RootConstantNode, 4)) - NewParameter.Constants.Num32BitValues = *Val; + Constants.Num32BitValues = *Val; else return reportError(Ctx, "Invalid value for Num32BitValues"); - RSD.Parameters.push_back(NewParameter); + RSD.ParametersContainer.addParameter(Header, Constants); return false; } @@ -164,12 +167,12 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { return reportValueError(Ctx, "RootFlags", RSD.Flags); } - for (const mcdxbc::RootParameter &P : RSD.Parameters) { - if (!dxbc::isValidShaderVisibility(P.Header.ShaderVisibility)) + for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { + if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) return reportValueError(Ctx, "ShaderVisibility", - P.Header.ShaderVisibility); + Info.Header.ShaderVisibility); - assert(dxbc::isValidParameterType(P.Header.ParameterType) && + assert(dxbc::isValidParameterType(Info.Header.ParameterType) && "Invalid value for ParameterType"); } @@ -287,25 +290,33 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, OS << indent(Space) << "Version: " << RS.Version << "\n"; OS << indent(Space) << "RootParametersOffset: " << RS.RootParameterOffset << "\n"; - OS << indent(Space) << "NumParameters: " << RS.Parameters.size() << "\n"; + OS << indent(Space) << "NumParameters: " << RS.ParametersContainer.size() + << "\n"; Space++; - for (auto const &P : RS.Parameters) { - OS << indent(Space) << "- Parameter Type: " << P.Header.ParameterType - << "\n"; + for (size_t I = 0; I < RS.ParametersContainer.size(); I++) { + const auto &[Type, Loc] = + RS.ParametersContainer.getTypeAndLocForParameter(I); + const dxbc::RootParameterHeader Header = + RS.ParametersContainer.getHeader(I); + + OS << indent(Space) << "- Parameter Type: " << Type << "\n"; OS << indent(Space + 2) - << "Shader Visibility: " << P.Header.ShaderVisibility << "\n"; - switch (P.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - OS << indent(Space + 2) - << "Register Space: " << P.Constants.RegisterSpace << "\n"; + << "Shader Visibility: " << Header.ShaderVisibility << "\n"; + + switch (Type) { + case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): { + const dxbc::RootConstants &Constants = + RS.ParametersContainer.getConstant(Loc); + OS << indent(Space + 2) << "Register Space: " << Constants.RegisterSpace + << "\n"; OS << indent(Space + 2) - << "Shader Register: " << P.Constants.ShaderRegister << "\n"; + << "Shader Register: " << Constants.ShaderRegister << "\n"; OS << indent(Space + 2) - << "Num 32 Bit Values: " << P.Constants.Num32BitValues << "\n"; - break; + << "Num 32 Bit Values: " << Constants.Num32BitValues << "\n"; } + } + Space--; } - Space--; OS << indent(Space) << "NumStaticSamplers: " << 0 << "\n"; OS << indent(Space) << "StaticSamplersOffset: " << RS.StaticSamplersOffset << "\n"; @@ -313,7 +324,6 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, Space--; // end root signature header } - return PreservedAnalyses::all(); }