Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Parse/ParseHLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class RootSignatureParser {
/// Root Element parse methods:
std::optional<llvm::hlsl::rootsig::RootFlags> parseRootFlags();
std::optional<llvm::hlsl::rootsig::RootConstants> parseRootConstants();
std::optional<llvm::hlsl::rootsig::RootParam> parseRootParam();
std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable();
std::optional<llvm::hlsl::rootsig::DescriptorTableClause>
parseDescriptorTableClause();
Expand Down
43 changes: 43 additions & 0 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ bool RootSignatureParser::parse() {
return true;
Elements.push_back(*Table);
}

if (tryConsumeExpectedToken(
{TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
auto RootParam = parseRootParam();
if (!RootParam.has_value())
return true;
Elements.push_back(*RootParam);
}
} while (tryConsumeExpectedToken(TokenKind::pu_comma));

return consumeExpectedToken(TokenKind::end_of_stream,
Expand Down Expand Up @@ -155,6 +163,41 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
return Constants;
}

std::optional<RootParam> RootSignatureParser::parseRootParam() {
assert((CurToken.TokKind == TokenKind::kw_CBV ||
CurToken.TokKind == TokenKind::kw_SRV ||
CurToken.TokKind == TokenKind::kw_UAV) &&
"Expects to only be invoked starting at given keyword");

TokenKind ParamKind = CurToken.TokKind;

if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
CurToken.TokKind))
return std::nullopt;

RootParam Param;
switch (ParamKind) {
default:
llvm_unreachable("Switch for consumed token was not provided");
case TokenKind::kw_CBV:
Param.Type = ParamType::CBuffer;
break;
case TokenKind::kw_SRV:
Param.Type = ParamType::SRV;
break;
case TokenKind::kw_UAV:
Param.Type = ParamType::UAV;
break;
}

if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/TokenKind::kw_RootConstants))
return std::nullopt;

return Param;
}

std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
"Expects to only be invoked starting at given keyword");
Expand Down
37 changes: 37 additions & 0 deletions clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,43 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
const llvm::StringLiteral Source = R"cc(
CBV(),
SRV(),
UAV()
)cc";

TrivialModuleLoader ModLoader;
auto PP = createPP(Source, ModLoader);
auto TokLoc = SourceLocation();

hlsl::RootSignatureLexer Lexer(Source, TokLoc);
SmallVector<RootElement> Elements;
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);

// Test no diagnostics produced
Consumer->setNoDiag();

ASSERT_FALSE(Parser.parse());

ASSERT_EQ(Elements.size(), 3u);

RootElement Elem = Elements[0];
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
ASSERT_EQ(std::get<RootParam>(Elem).Type, ParamType::CBuffer);

Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
ASSERT_EQ(std::get<RootParam>(Elem).Type, ParamType::SRV);

Elem = Elements[2];
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
ASSERT_EQ(std::get<RootParam>(Elem).Type, ParamType::UAV);

ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
// This test will checks we can handling trailing commas ','
const llvm::StringLiteral Source = R"cc(
Expand Down
14 changes: 9 additions & 5 deletions llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ struct RootConstants {
ShaderVisibility Visibility = ShaderVisibility::All;
};

using ParamType = llvm::dxil::ResourceClass;
struct RootParam {
ParamType Type;
};

// Models the end of a descriptor table and stores its visibility
struct DescriptorTable {
ShaderVisibility Visibility = ShaderVisibility::All;
Expand Down Expand Up @@ -125,8 +130,8 @@ struct DescriptorTableClause {
void dump(raw_ostream &OS) const;
};

/// Models RootElement : RootFlags | RootConstants | DescriptorTable
/// | DescriptorTableClause
/// Models RootElement : RootFlags | RootConstants | RootParam
/// | DescriptorTable | DescriptorTableClause
///
/// A Root Signature is modeled in-memory by an array of RootElements. These
/// aim to map closely to their DSL grammar reprsentation defined in the spec.
Expand All @@ -140,9 +145,8 @@ struct DescriptorTableClause {
/// The DescriptorTable is modelled by having its Clauses as the previous
/// RootElements in the array, and it holds a data member for the Visibility
/// parameter.
using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
DescriptorTableClause>;

using RootElement = std::variant<RootFlags, RootConstants, RootParam,
DescriptorTable, DescriptorTableClause>;
void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements);

class MetadataBuilder {
Expand Down