Skip to content
11 changes: 7 additions & 4 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Frontend/HLSL/HLSLBinding.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -1218,9 +1219,9 @@ bool SemaHLSL::handleRootSignatureElements(
ReportError(Loc, 0, 0xffffffef);
};

const uint32_t Version =
llvm::to_underlying(SemaRef.getLangOpts().HLSLRootSigVer);
const uint32_t VersionEnum = Version - 1;
const llvm::dxbc::RootSignatureVersion Version =
SemaRef.getLangOpts().HLSLRootSigVer;
const uint32_t VersionEnum = static_cast<uint32_t>(Version) - 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a hack used to print the minor version number. We should instead have a better enum->string conversion that can produce the correct version number.

auto ReportFlagError = [this, &HadError, VersionEnum](SourceLocation Loc) {
HadError = true;
this->Diag(Loc, diag::err_hlsl_invalid_rootsig_flag)
Expand Down Expand Up @@ -1270,7 +1271,9 @@ bool SemaHLSL::handleRootSignatureElements(
}

if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
Version, llvm::to_underlying(Clause->Type),
Version,
llvm::dxbc::DescriptorRangeType(
llvm::to_underlying(Clause->Type)),
llvm::to_underlying(Clause->Flags)))
ReportFlagError(Loc);
}
Expand Down
44 changes: 23 additions & 21 deletions llvm/include/llvm/BinaryFormat/DXContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ enum class DescriptorRangeType : uint32_t {
LLVM_ABI ArrayRef<EnumEntry<DescriptorRangeType>> getDescriptorRangeTypes();

#define ROOT_PARAMETER(Val, Enum) \
case Val: \
case dxbc::RootParameterType::Enum: \
return true;
inline bool isValidParameterType(uint32_t V) {
inline bool isValidParameterType(dxbc::RootParameterType V) {
switch (V) {
#include "DXContainerConstants.def"
}
Expand All @@ -217,9 +217,9 @@ enum class ShaderVisibility : uint32_t {
LLVM_ABI ArrayRef<EnumEntry<ShaderVisibility>> getShaderVisibility();

#define SHADER_VISIBILITY(Val, Enum) \
case Val: \
case dxbc::ShaderVisibility::Enum: \
return true;
inline bool isValidShaderVisibility(uint32_t V) {
inline bool isValidShaderVisibility(dxbc::ShaderVisibility V) {
switch (V) {
#include "DXContainerConstants.def"
}
Expand Down Expand Up @@ -254,6 +254,14 @@ enum class StaticBorderColor : uint32_t {

LLVM_ABI ArrayRef<EnumEntry<StaticBorderColor>> getStaticBorderColors();

// D3D_ROOT_SIGNATURE_VERSION
enum class RootSignatureVersion {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
enum class RootSignatureVersion {
enum class RootSignatureVersion : unit32_t {

V1_0 = 0x1,
V1_1 = 0x2,
};

LLVM_ABI ArrayRef<EnumEntry<RootSignatureVersion>> getRootSignatureVersions();

LLVM_ABI PartType parsePartType(StringRef S);

struct VertexPSVInfo {
Expand Down Expand Up @@ -646,19 +654,19 @@ static_assert(sizeof(ProgramSignatureElement) == 32,
namespace RTS0 {
namespace v1 {
struct StaticSampler {
uint32_t Filter;
uint32_t AddressU;
uint32_t AddressV;
uint32_t AddressW;
dxbc::SamplerFilter Filter;
dxbc::TextureAddressMode AddressU;
dxbc::TextureAddressMode AddressV;
dxbc::TextureAddressMode AddressW;
float MipLODBias;
uint32_t MaxAnisotropy;
uint32_t ComparisonFunc;
uint32_t BorderColor;
dxbc::ComparisonFunc ComparisonFunc;
dxbc::StaticBorderColor BorderColor;
float MinLOD;
float MaxLOD;
uint32_t ShaderRegister;
uint32_t RegisterSpace;
uint32_t ShaderVisibility;
dxbc::ShaderVisibility ShaderVisibility;
void swapBytes() {
sys::swapByteOrder(Filter);
sys::swapByteOrder(AddressU);
Expand All @@ -677,7 +685,7 @@ struct StaticSampler {
};

struct DescriptorRange {
uint32_t RangeType;
dxbc::DescriptorRangeType RangeType;
uint32_t NumDescriptors;
uint32_t BaseShaderRegister;
uint32_t RegisterSpace;
Expand Down Expand Up @@ -715,8 +723,8 @@ struct RootConstants {
};

struct RootParameterHeader {
uint32_t ParameterType;
uint32_t ShaderVisibility;
dxbc::RootParameterType ParameterType;
dxbc::ShaderVisibility ShaderVisibility;
uint32_t ParameterOffset;

void swapBytes() {
Expand Down Expand Up @@ -760,7 +768,7 @@ struct RootDescriptor : public v1::RootDescriptor {
};

struct DescriptorRange {
uint32_t RangeType;
dxbc::DescriptorRangeType RangeType;
uint32_t NumDescriptors;
uint32_t BaseShaderRegister;
uint32_t RegisterSpace;
Expand All @@ -778,12 +786,6 @@ struct DescriptorRange {
} // namespace v2
} // namespace RTS0

// D3D_ROOT_SIGNATURE_VERSION
enum class RootSignatureVersion {
V1_0 = 0x1,
V1_1 = 0x2,
};

} // namespace dxbc
} // namespace llvm

Expand Down
9 changes: 5 additions & 4 deletions llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "llvm/IR/Constants.h"
#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/Support/Compiler.h"
#include <cstdint>

namespace llvm {
class LLVMContext;
Expand All @@ -28,9 +29,9 @@ class Metadata;
namespace hlsl {
namespace rootsig {

template <typename T>
template <typename T, typename ET = T>
class RootSignatureValidationError
: public ErrorInfo<RootSignatureValidationError<T>> {
: public ErrorInfo<RootSignatureValidationError<T, ET>> {
public:
static char ID;
StringRef ParamName;
Expand All @@ -40,7 +41,7 @@ class RootSignatureValidationError
: ParamName(ParamName), Value(Value) {}

void log(raw_ostream &OS) const override {
OS << "Invalid value for " << ParamName << ": " << Value;
OS << "Invalid value for " << ParamName << ": " << static_cast<ET>(Value);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably eliminate the ET parameter type by doing something like this:

Suggested change
OS << "Invalid value for " << ParamName << ": " << static_cast<ET>(Value);
OS << "Invalid value for " << ParamName << ": ";
if constexpr (std::is_enum<T>)
OS << llvm::to_underlying(Value);
else
OS << Value;

}

std::error_code convertToErrorCode() const override {
Expand Down Expand Up @@ -143,7 +144,7 @@ class MetadataParser {
MetadataParser(MDNode *Root) : Root(Root) {}

LLVM_ABI llvm::Expected<llvm::mcdxbc::RootSignatureDesc>
ParseRootSignature(uint32_t Version);
ParseRootSignature(dxbc::RootSignatureVersion Version);

private:
llvm::Error parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
Expand Down
19 changes: 11 additions & 8 deletions llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREVALIDATIONS_H

#include "llvm/ADT/IntervalMap.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
#include "llvm/Support/Compiler.h"

Expand All @@ -25,20 +26,22 @@ namespace rootsig {
// Basic verification of RootElements

LLVM_ABI bool verifyRootFlag(uint32_t Flags);
LLVM_ABI bool verifyVersion(uint32_t Version);
LLVM_ABI bool verifyVersion(dxbc::RootSignatureVersion Version);
LLVM_ABI bool verifyRegisterValue(uint32_t RegisterValue);
LLVM_ABI bool verifyRegisterSpace(uint32_t RegisterSpace);
LLVM_ABI bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal);
LLVM_ABI bool verifyRangeType(uint32_t Type);
LLVM_ABI bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
LLVM_ABI bool verifyRootDescriptorFlag(dxbc::RootSignatureVersion Version,
uint32_t FlagsVal);
LLVM_ABI bool verifyRangeType(dxbc::DescriptorRangeType Type);
LLVM_ABI bool verifyDescriptorRangeFlag(dxbc::RootSignatureVersion Version,
dxbc::DescriptorRangeType Type,
uint32_t FlagsVal);
LLVM_ABI bool verifyNumDescriptors(uint32_t NumDescriptors);
LLVM_ABI bool verifySamplerFilter(uint32_t Value);
LLVM_ABI bool verifyAddress(uint32_t Address);
LLVM_ABI bool verifySamplerFilter(dxbc::SamplerFilter Value);
LLVM_ABI bool verifyAddress(dxbc::TextureAddressMode Address);
LLVM_ABI bool verifyMipLODBias(float MipLODBias);
LLVM_ABI bool verifyMaxAnisotropy(uint32_t MaxAnisotropy);
LLVM_ABI bool verifyComparisonFunc(uint32_t ComparisonFunc);
LLVM_ABI bool verifyBorderColor(uint32_t BorderColor);
LLVM_ABI bool verifyComparisonFunc(dxbc::ComparisonFunc ComparisonFunc);
LLVM_ABI bool verifyBorderColor(dxbc::StaticBorderColor BorderColor);
LLVM_ABI bool verifyLOD(float LOD);

} // namespace rootsig
Expand Down
4 changes: 2 additions & 2 deletions llvm/include/llvm/MC/DXContainerRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct RootParametersContainer {
Tables.push_back(Table);
}

std::pair<uint32_t, uint32_t>
std::pair<dxbc::RootParameterType, uint32_t>
getTypeAndLocForParameter(uint32_t Location) const {
const RootParameterInfo &Info = ParametersInfo[Location];
return {Info.Header.ParameterType, Info.Location};
Expand Down Expand Up @@ -106,7 +106,7 @@ struct RootParametersContainer {
};
struct RootSignatureDesc {

uint32_t Version = 2U;
dxbc::RootSignatureVersion Version = dxbc::RootSignatureVersion::V1_1;
uint32_t Flags = 0U;
uint32_t RootParameterOffset = 0U;
uint32_t StaticSamplersOffset = 0u;
Expand Down
36 changes: 17 additions & 19 deletions llvm/include/llvm/Object/DXContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ struct RootParameterView {

struct RootConstantView : RootParameterView {
static bool classof(const RootParameterView *V) {
return V->Header.ParameterType ==
(uint32_t)dxbc::RootParameterType::Constants32Bit;
return V->Header.ParameterType == dxbc::RootParameterType::Constants32Bit;
}

llvm::Expected<dxbc::RTS0::v1::RootConstants> read() {
Expand All @@ -157,25 +156,24 @@ struct RootConstantView : RootParameterView {

struct RootDescriptorView : RootParameterView {
static bool classof(const RootParameterView *V) {
return (V->Header.ParameterType ==
llvm::to_underlying(dxbc::RootParameterType::CBV) ||
V->Header.ParameterType ==
llvm::to_underlying(dxbc::RootParameterType::SRV) ||
V->Header.ParameterType ==
llvm::to_underlying(dxbc::RootParameterType::UAV));
return (V->Header.ParameterType == dxbc::RootParameterType::CBV ||
V->Header.ParameterType == dxbc::RootParameterType::SRV ||
V->Header.ParameterType == dxbc::RootParameterType::UAV);
}

llvm::Expected<dxbc::RTS0::v2::RootDescriptor> read(uint32_t Version) {
if (Version == 1) {
llvm::Expected<dxbc::RTS0::v2::RootDescriptor>
read(dxbc::RootSignatureVersion Version) {
if (Version == dxbc::RootSignatureVersion::V1_0) {
auto Descriptor = readParameter<dxbc::RTS0::v1::RootDescriptor>();
if (Error E = Descriptor.takeError())
return E;
return dxbc::RTS0::v2::RootDescriptor(*Descriptor);
}
if (Version != 2)
return make_error<GenericBinaryError>("Invalid Root Signature version: " +
Twine(Version),
object_error::parse_failed);
if (Version != dxbc::RootSignatureVersion::V1_1)
return make_error<GenericBinaryError>(
"Invalid Root Signature version: " +
Twine(static_cast<uint32_t>(Version)),
object_error::parse_failed);
return readParameter<dxbc::RTS0::v2::RootDescriptor>();
}
};
Expand All @@ -192,7 +190,7 @@ template <typename T> struct DescriptorTable {
struct DescriptorTableView : RootParameterView {
static bool classof(const RootParameterView *V) {
return (V->Header.ParameterType ==
llvm::to_underlying(dxbc::RootParameterType::DescriptorTable));
dxbc::RootParameterType::DescriptorTable);
}

// Define a type alias to access the template parameter from inside classof
Expand Down Expand Up @@ -220,7 +218,7 @@ static Error parseFailed(const Twine &Msg) {

class RootSignature {
private:
uint32_t Version;
dxbc::RootSignatureVersion Version;
uint32_t NumParameters;
uint32_t RootParametersOffset;
uint32_t NumStaticSamplers;
Expand All @@ -238,7 +236,7 @@ class RootSignature {
RootSignature(StringRef PD) : PartData(PD) {}

LLVM_ABI Error parse();
uint32_t getVersion() const { return Version; }
dxbc::RootSignatureVersion getVersion() const { return Version; }
uint32_t getNumParameters() const { return NumParameters; }
uint32_t getRootParametersOffset() const { return RootParametersOffset; }
uint32_t getNumStaticSamplers() const { return NumStaticSamplers; }
Expand Down Expand Up @@ -269,7 +267,7 @@ class RootSignature {
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::SRV:
case dxbc::RootParameterType::UAV:
if (Version == 1)
if (Version == dxbc::RootSignatureVersion::V1_0)
DataSize = sizeof(dxbc::RTS0::v1::RootDescriptor);
else
DataSize = sizeof(dxbc::RTS0::v2::RootDescriptor);
Expand All @@ -281,7 +279,7 @@ class RootSignature {
uint32_t NumRanges =
support::endian::read<uint32_t, llvm::endianness::little>(
PartData.begin() + Header.ParameterOffset);
if (Version == 1)
if (Version == dxbc::RootSignatureVersion::V1_0)
DataSize = sizeof(dxbc::RTS0::v1::DescriptorRange) * NumRanges;
else
DataSize = sizeof(dxbc::RTS0::v2::DescriptorRange) * NumRanges;
Expand Down
Loading