Skip to content

Commit d4608c3

Browse files
committed
add support for parsing a float for float param
1 parent 9be00bd commit d4608c3

File tree

3 files changed

+97
-9
lines changed

3 files changed

+97
-9
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,11 @@ class RootSignatureParser {
130130
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
131131
/// 32-bit integer
132132
std::optional<uint32_t> handleUIntLiteral();
133-
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
133+
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a signed
134134
/// 32-bit integer
135135
std::optional<int32_t> handleIntLiteral(bool Negated);
136+
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a float
137+
std::optional<float> handleFloatLiteral(bool Negated);
136138

137139
/// Flags may specify the value of '0' to denote that there should be no
138140
/// flags set.

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <float.h>
10+
911
#include "clang/Parse/ParseHLSLRootSignature.h"
1012

1113
#include "clang/Lex/LiteralSupport.h"
@@ -734,7 +736,8 @@ std::optional<float> RootSignatureParser::parseFloatParam() {
734736
assert(CurToken.TokKind == TokenKind::pu_equal &&
735737
"Expects to only be invoked starting at given keyword");
736738
// Consume sign modifier
737-
bool Signed = tryConsumeExpectedToken({TokenKind::pu_plus, TokenKind::pu_minus});
739+
bool Signed =
740+
tryConsumeExpectedToken({TokenKind::pu_plus, TokenKind::pu_minus});
738741
bool Negated = Signed && CurToken.TokKind == TokenKind::pu_minus;
739742

740743
// Handle an uint and interpret it as a float
@@ -747,8 +750,12 @@ std::optional<float> RootSignatureParser::parseFloatParam() {
747750
auto Int = handleIntLiteral(Negated);
748751
if (!Int.has_value())
749752
return std::nullopt;
750-
751753
return (float)Int.value();
754+
} else if (tryConsumeExpectedToken(TokenKind::float_literal)) {
755+
auto Float = handleFloatLiteral(Negated);
756+
if (!Float.has_value())
757+
return std::nullopt;
758+
return Float.value();
752759
}
753760

754761
return std::nullopt;
@@ -864,9 +871,10 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
864871
PP.getSourceManager(), PP.getLangOpts(),
865872
PP.getTargetInfo(), PP.getDiagnostics());
866873
if (Literal.hadError)
867-
return true; // Error has already been reported so just return
874+
return std::nullopt; // Error has already been reported so just return
868875

869-
assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
876+
assert(Literal.isIntegerLiteral() &&
877+
"NumSpelling can only consist of digits");
870878

871879
llvm::APSInt Val = llvm::APSInt(32, false);
872880
if (Literal.GetIntegerValue(Val)) {
@@ -886,9 +894,10 @@ std::optional<int32_t> RootSignatureParser::handleIntLiteral(bool Negated) {
886894
PP.getSourceManager(), PP.getLangOpts(),
887895
PP.getTargetInfo(), PP.getDiagnostics());
888896
if (Literal.hadError)
889-
return true; // Error has already been reported so just return
897+
return std::nullopt; // Error has already been reported so just return
890898

891-
assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
899+
assert(Literal.isIntegerLiteral() &&
900+
"NumSpelling can only consist of digits");
892901

893902
llvm::APSInt Val = llvm::APSInt(32, true);
894903
if (Literal.GetIntegerValue(Val) || INT32_MAX < Val.getExtValue()) {
@@ -900,11 +909,48 @@ std::optional<int32_t> RootSignatureParser::handleIntLiteral(bool Negated) {
900909
}
901910

902911
if (Negated)
903-
return static_cast<int32_t>((-Val).getExtValue());
912+
Val = -Val;
904913

905914
return static_cast<int32_t>(Val.getExtValue());
906915
}
907916

917+
std::optional<float> RootSignatureParser::handleFloatLiteral(bool Negated) {
918+
// Parse the numeric value and do semantic checks on its specification
919+
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
920+
PP.getSourceManager(), PP.getLangOpts(),
921+
PP.getTargetInfo(), PP.getDiagnostics());
922+
if (Literal.hadError)
923+
return std::nullopt; // Error has already been reported so just return
924+
925+
assert(Literal.isFloatingLiteral() &&
926+
"NumSpelling is consistent with isNumberChar in "
927+
"LexHLSLRootSignature.cpp");
928+
929+
// DXC used `strtod` to convert the token string to a float which corresponds
930+
// to:
931+
auto DXCSemantics = llvm::APFloat::Semantics::S_IEEEdouble;
932+
auto DXCRoundingMode = llvm::RoundingMode::NearestTiesToEven;
933+
934+
llvm::APFloat Val =
935+
llvm::APFloat(llvm::APFloat::EnumToSemantics(DXCSemantics));
936+
llvm::APFloat::opStatus Status = Literal.GetFloatValue(Val, DXCRoundingMode);
937+
938+
// The float is valid with opInexect as this just denotes if rounding occured
939+
if (Status != llvm::APFloat::opStatus::opOK &&
940+
Status != llvm::APFloat::opStatus::opInexact)
941+
return std::nullopt;
942+
943+
if (Negated)
944+
Val = -Val;
945+
946+
double DoubleVal = Val.convertToDouble();
947+
if (FLT_MAX < DoubleVal || DoubleVal < -FLT_MAX) {
948+
return std::nullopt;
949+
}
950+
951+
return static_cast<float>(DoubleVal);
952+
}
953+
908954
bool RootSignatureParser::verifyZeroFlag() {
909955
assert(CurToken.TokKind == TokenKind::int_literal);
910956
auto X = handleUIntLiteral();

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,15 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseFloatsTest) {
256256
const llvm::StringLiteral Source = R"cc(
257257
StaticSampler(s0, mipLODBias = 0),
258258
StaticSampler(s0, mipLODBias = +1),
259-
StaticSampler(s0, mipLODBias = -1)
259+
StaticSampler(s0, mipLODBias = -1),
260+
StaticSampler(s0, mipLODBias = 42.),
261+
StaticSampler(s0, mipLODBias = +4.2),
262+
StaticSampler(s0, mipLODBias = -.42),
263+
StaticSampler(s0, mipLODBias = .42e+3),
264+
StaticSampler(s0, mipLODBias = 42E-12),
265+
StaticSampler(s0, mipLODBias = 42.f),
266+
StaticSampler(s0, mipLODBias = 4.2F),
267+
StaticSampler(s0, mipLODBias = 42.e+10f),
260268
)cc";
261269

262270
TrivialModuleLoader ModLoader;
@@ -284,6 +292,38 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseFloatsTest) {
284292
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
285293
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, -1.f);
286294

295+
Elem = Elements[3];
296+
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
297+
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 42.f);
298+
299+
Elem = Elements[4];
300+
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
301+
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 4.2f);
302+
303+
Elem = Elements[5];
304+
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
305+
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, -.42f);
306+
307+
Elem = Elements[6];
308+
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
309+
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 420.f);
310+
311+
Elem = Elements[7];
312+
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
313+
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 0.000000000042f);
314+
315+
Elem = Elements[8];
316+
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
317+
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 42.f);
318+
319+
Elem = Elements[9];
320+
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
321+
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 4.2f);
322+
323+
Elem = Elements[10];
324+
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem));
325+
ASSERT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 420000000000.f);
326+
287327
ASSERT_TRUE(Consumer->isSatisfied());
288328
}
289329

0 commit comments

Comments
 (0)