Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,15 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
default:
llvm_unreachable("Switch for consumed token was not provided");
case TokenKind::kw_CBV:
Descriptor.Type = DescriptorType::CBuffer;
Descriptor.Type = ResourceClass::CBuffer;
ExpectedReg = TokenKind::bReg;
break;
case TokenKind::kw_SRV:
Descriptor.Type = DescriptorType::SRV;
Descriptor.Type = ResourceClass::SRV;
ExpectedReg = TokenKind::tReg;
break;
case TokenKind::kw_UAV:
Descriptor.Type = DescriptorType::UAV;
Descriptor.Type = ResourceClass::UAV;
ExpectedReg = TokenKind::uReg;
break;
}
Expand Down Expand Up @@ -360,19 +360,19 @@ RootSignatureParser::parseDescriptorTableClause() {
default:
llvm_unreachable("Switch for consumed token was not provided");
case TokenKind::kw_CBV:
Clause.Type = ClauseType::CBuffer;
Clause.Type = ResourceClass::CBuffer;
ExpectedReg = TokenKind::bReg;
break;
case TokenKind::kw_SRV:
Clause.Type = ClauseType::SRV;
Clause.Type = ResourceClass::SRV;
ExpectedReg = TokenKind::tReg;
break;
case TokenKind::kw_UAV:
Clause.Type = ClauseType::UAV;
Clause.Type = ResourceClass::UAV;
ExpectedReg = TokenKind::uReg;
break;
case TokenKind::kw_Sampler:
Clause.Type = ClauseType::Sampler;
Clause.Type = ResourceClass::Sampler;
ExpectedReg = TokenKind::sReg;
break;
}
Expand Down
46 changes: 23 additions & 23 deletions clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
// First Descriptor Table with 4 elements
RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::CBuffer);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType,
RegisterType::BReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
Expand All @@ -193,7 +193,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {

Elem = Elements[1].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::SRV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType,
RegisterType::TReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
Expand All @@ -205,7 +205,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {

Elem = Elements[2].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::Sampler);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType,
RegisterType::SReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
Expand All @@ -218,7 +218,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {

Elem = Elements[3].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::UAV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType,
RegisterType::UReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
Expand Down Expand Up @@ -445,7 +445,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
auto Elements = Parser.getElements();
RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::Sampler);
auto ValidSamplerFlags =
llvm::dxbc::DescriptorRangeFlags::DescriptorsVolatile;
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, ValidSamplerFlags);
Expand Down Expand Up @@ -591,7 +591,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {

RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::CBuffer);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::CBuffer);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::BReg);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 0u);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 0u);
Expand All @@ -602,7 +602,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {

Elem = Elements[1].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::SRV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::SRV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::TReg);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 42u);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 4u);
Expand All @@ -616,7 +616,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {

Elem = Elements[2].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::UAV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::UAV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::UReg);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 34893247u);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 0u);
Expand All @@ -628,7 +628,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
RootDescriptorFlags::DataVolatile);

Elem = Elements[3].getElement();
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::CBuffer);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::CBuffer);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::BReg);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 0u);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 0u);
Expand Down Expand Up @@ -696,40 +696,40 @@ TEST_F(ParseHLSLRootSignatureTest, ValidVersion10Test) {
auto DefRootDescriptorFlag = llvm::dxbc::RootDescriptorFlags::DataVolatile;
RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::CBuffer);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::CBuffer);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags, DefRootDescriptorFlag);

Elem = Elements[1].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::SRV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::SRV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags, DefRootDescriptorFlag);

Elem = Elements[2].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::UAV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::UAV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags, DefRootDescriptorFlag);

auto ValidNonSamplerFlags =
llvm::dxbc::DescriptorRangeFlags::DescriptorsVolatile |
llvm::dxbc::DescriptorRangeFlags::DataVolatile;
Elem = Elements[3].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::CBuffer);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, ValidNonSamplerFlags);

Elem = Elements[4].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::SRV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, ValidNonSamplerFlags);

Elem = Elements[5].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::UAV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, ValidNonSamplerFlags);

Elem = Elements[6].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::Sampler);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
llvm::dxbc::DescriptorRangeFlags::DescriptorsVolatile);

Expand Down Expand Up @@ -767,43 +767,43 @@ TEST_F(ParseHLSLRootSignatureTest, ValidVersion11Test) {
auto Elements = Parser.getElements();
RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::CBuffer);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::CBuffer);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags,
llvm::dxbc::RootDescriptorFlags::DataStaticWhileSetAtExecute);

Elem = Elements[1].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::SRV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::SRV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags,
llvm::dxbc::RootDescriptorFlags::DataStaticWhileSetAtExecute);

Elem = Elements[2].getElement();
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::UAV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, ResourceClass::UAV);
ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags,
llvm::dxbc::RootDescriptorFlags::DataVolatile);

Elem = Elements[3].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::CBuffer);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
llvm::dxbc::DescriptorRangeFlags::DataStaticWhileSetAtExecute);

Elem = Elements[4].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::SRV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
llvm::dxbc::DescriptorRangeFlags::DataStaticWhileSetAtExecute);

Elem = Elements[5].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::UAV);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
llvm::dxbc::DescriptorRangeFlags::DataVolatile);

Elem = Elements[6].getElement();
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ResourceClass::Sampler);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
llvm::dxbc::DescriptorRangeFlags::None);

Expand Down
25 changes: 13 additions & 12 deletions llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@ struct RootConstants {
dxbc::ShaderVisibility Visibility = dxbc::ShaderVisibility::All;
};

enum class DescriptorType : uint8_t { SRV = 0, UAV, CBuffer };
// Models RootDescriptor : CBV | SRV | UAV, by collecting like parameters
struct RootDescriptor {
DescriptorType Type;
dxil::ResourceClass Type;
Register Reg;
uint32_t Space = 0;
dxbc::ShaderVisibility Visibility = dxbc::ShaderVisibility::All;
Expand All @@ -60,13 +59,16 @@ struct RootDescriptor {
assert(Version == llvm::dxbc::RootSignatureVersion::V1_1 &&
"Specified an invalid root signature version");
switch (Type) {
case DescriptorType::CBuffer:
case DescriptorType::SRV:
case dxil::ResourceClass::CBuffer:
case dxil::ResourceClass::SRV:
Flags = dxbc::RootDescriptorFlags::DataStaticWhileSetAtExecute;
break;
case DescriptorType::UAV:
case dxil::ResourceClass::UAV:
Flags = dxbc::RootDescriptorFlags::DataVolatile;
break;
case dxil::ResourceClass::Sampler:
llvm_unreachable(
"ResourceClass::Sampler is not valid for RootDescriptors");
}
}
};
Expand All @@ -82,9 +84,8 @@ struct DescriptorTable {
static const uint32_t NumDescriptorsUnbounded = 0xffffffff;
static const uint32_t DescriptorTableOffsetAppend = 0xffffffff;
// Models DTClause : CBV | SRV | UAV | Sampler, by collecting like parameters
using ClauseType = llvm::dxil::ResourceClass;
struct DescriptorTableClause {
ClauseType Type;
dxil::ResourceClass Type;
Register Reg;
uint32_t NumDescriptors = 1;
uint32_t Space = 0;
Expand All @@ -94,22 +95,22 @@ struct DescriptorTableClause {
void setDefaultFlags(dxbc::RootSignatureVersion Version) {
if (Version == dxbc::RootSignatureVersion::V1_0) {
Flags = dxbc::DescriptorRangeFlags::DescriptorsVolatile;
if (Type != ClauseType::Sampler)
if (Type != dxil::ResourceClass::Sampler)
Flags |= dxbc::DescriptorRangeFlags::DataVolatile;
return;
}

assert(Version == dxbc::RootSignatureVersion::V1_1 &&
"Specified an invalid root signature version");
switch (Type) {
case ClauseType::CBuffer:
case ClauseType::SRV:
case dxil::ResourceClass::CBuffer:
case dxil::ResourceClass::SRV:
Flags = dxbc::DescriptorRangeFlags::DataStaticWhileSetAtExecute;
break;
case ClauseType::UAV:
case dxil::ResourceClass::UAV:
Flags = dxbc::DescriptorRangeFlags::DataVolatile;
break;
case ClauseType::Sampler:
case dxil::ResourceClass::Sampler:
Flags = dxbc::DescriptorRangeFlags::None;
break;
}
Expand Down
9 changes: 4 additions & 5 deletions llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ static raw_ostream &operator<<(raw_ostream &OS,
return OS;
}

static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
OS << enumToStringRef(dxil::ResourceClass(llvm::to_underlying(Type)),
dxil::getResourceClasses());
static raw_ostream &operator<<(raw_ostream &OS,
const dxil::ResourceClass &Type) {
OS << enumToStringRef(Type, dxil::getResourceClasses());

return OS;
}
Expand Down Expand Up @@ -153,8 +153,7 @@ raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) {
}

raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type));
OS << "Root" << Type << "(" << Descriptor.Reg
OS << "Root" << Descriptor.Type << "(" << Descriptor.Reg
<< ", space = " << Descriptor.Space
<< ", visibility = " << Descriptor.Visibility
<< ", flags = " << Descriptor.Flags << ")";
Expand Down
7 changes: 2 additions & 5 deletions llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
IRBuilder<> Builder(Ctx);
StringRef ResName =
enumToStringRef(dxil::ResourceClass(to_underlying(Descriptor.Type)),
dxil::getResourceClasses());
enumToStringRef(Descriptor.Type, dxil::getResourceClasses());
assert(!ResName.empty() && "Provided an invalid Resource Class");
SmallString<7> Name({"Root", ResName});
Metadata *Operands[] = {
Expand Down Expand Up @@ -161,9 +160,7 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
MDNode *MetadataBuilder::BuildDescriptorTableClause(
const DescriptorTableClause &Clause) {
IRBuilder<> Builder(Ctx);
StringRef ResName =
enumToStringRef(dxil::ResourceClass(to_underlying(Clause.Type)),
dxil::getResourceClasses());
StringRef ResName = enumToStringRef(Clause.Type, dxil::getResourceClasses());
assert(!ResName.empty() && "Provided an invalid Resource Class");
Metadata *Operands[] = {
MDString::get(Ctx, ResName),
Expand Down
Loading