Skip to content

Commit d248a39

Browse files
committed
Type constants can now be used in constant expressions
1 parent 1ed1644 commit d248a39

File tree

5 files changed

+109
-0
lines changed

5 files changed

+109
-0
lines changed

include/NZSL/Ast/Transformations/ConstantPropagationTransformer.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ namespace nzsl::Ast
4949
ExpressionTransformation Transform(ConstantExpression&& node) override;
5050
ExpressionTransformation Transform(IntrinsicExpression&& node) override;
5151
ExpressionTransformation Transform(SwizzleExpression&& node) override;
52+
ExpressionTransformation Transform(TypeConstantExpression&& node) override;
5253
ExpressionTransformation Transform(UnaryExpression&& node) override;
5354

5455
StatementTransformation Transform(BranchStatement&& node) override;

include/NZSL/Ast/Utils.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ namespace nzsl::Ast
6363
NZSL_API std::optional<ExpressionType> ComputeExpressionType(const SwizzleExpression& swizzleExpr, const Stringifier& typeStringifier);
6464
NZSL_API std::optional<ExpressionType> ComputeExpressionType(const UnaryExpression& unaryExpr, const Stringifier& typeStringifier);
6565
NZSL_API ExpressionType ComputeSwizzleType(const ExpressionType& type, std::size_t componentCount, const SourceLocation& sourceLocation);
66+
NZSL_API ConstantSingleValue ComputeTypeConstant(const ExpressionType& expressionType, TypeConstant typeConstant);
6667

6768
NZSL_API float LiteralToFloat32(FloatLiteral literal, const SourceLocation& sourceLocation);
6869
NZSL_API double LiteralToFloat64(FloatLiteral literal, const SourceLocation& sourceLocation);

src/NZSL/Ast/Transformations/ConstantPropagationTransformer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,11 @@ namespace nzsl::Ast
704704
return DontVisitChildren{};
705705
}
706706

707+
auto ConstantPropagationTransformer::Transform(TypeConstantExpression&& node) -> ExpressionTransformation
708+
{
709+
return ReplaceExpression{ ShaderBuilder::ConstantValue(ComputeTypeConstant(node.type, node.typeConstant)) };
710+
}
711+
707712
auto ConstantPropagationTransformer::Transform(UnaryExpression&& node) -> ExpressionTransformation
708713
{
709714
HandleExpression(node.expression);

src/NZSL/Ast/Utils.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,56 @@ namespace nzsl::Ast
413413
return baseType;
414414
}
415415

416+
ConstantSingleValue ComputeTypeConstant(const ExpressionType& expressionType, TypeConstant typeConstant)
417+
{
418+
NazaraAssert(IsPrimitiveType(expressionType));
419+
PrimitiveType primitiveType = std::get<PrimitiveType>(expressionType);
420+
421+
auto ReplaceByValue = [&](auto&& type) -> ConstantSingleValue
422+
{
423+
using T = std::decay_t<decltype(type)>;
424+
425+
if (typeConstant == TypeConstant::Max)
426+
return Nz::MaxValue<T>();
427+
428+
if (typeConstant == TypeConstant::Min)
429+
return std::numeric_limits<T>::lowest(); //< Nz::MinValue is implemented by std::numeric_limits<T>::min() which doesn't give the value we want
430+
431+
if constexpr (std::is_floating_point_v<T>)
432+
{
433+
if (typeConstant == TypeConstant::Epsilon)
434+
return std::numeric_limits<T>::epsilon();
435+
436+
if (typeConstant == TypeConstant::Infinity)
437+
return Nz::Infinity<T>();
438+
439+
if (typeConstant == TypeConstant::MinPositive)
440+
return std::numeric_limits<T>::min();
441+
442+
if (typeConstant == TypeConstant::NaN)
443+
return Nz::NaN<T>();
444+
}
445+
446+
throw std::runtime_error("unexpected type constant with type");
447+
};
448+
449+
switch (primitiveType)
450+
{
451+
case PrimitiveType::Float32: return ReplaceByValue(float{});
452+
case PrimitiveType::Float64: return ReplaceByValue(double{});
453+
case PrimitiveType::Int32: return ReplaceByValue(std::int32_t{});
454+
case PrimitiveType::UInt32: return ReplaceByValue(std::uint32_t{});
455+
456+
case PrimitiveType::Boolean:
457+
case PrimitiveType::FloatLiteral:
458+
case PrimitiveType::IntLiteral:
459+
case PrimitiveType::String:
460+
break;
461+
}
462+
463+
throw std::runtime_error("unexpected primitive type");
464+
}
465+
416466
float LiteralToFloat32(FloatLiteral literal, const SourceLocation& /*sourceLocation*/)
417467
{
418468
return static_cast<float>(literal);

tests/src/Tests/ConstantTests.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,26 @@ TEST_CASE("constant", "[Shader]")
1414
[feature(float64)]
1515
module;
1616
17+
const f32Eps = f32.Epsilon;
18+
const f32Max = f32.Max;
19+
const f32Min = f32.Min;
20+
const f32MinPos = f32.MinPositive;
21+
const f32Inf = f32.Infinity;
22+
const f32NaN = f32.NaN;
23+
24+
const f64Eps = f64.Epsilon;
25+
const f64Max = f64.Max;
26+
const f64Min = f64.Min;
27+
const f64MinPos = f64.MinPositive;
28+
const f64Inf = f64.Infinity;
29+
const f64NaN = f64.NaN;
30+
31+
const i32Max = i32.Max;
32+
const i32Min = i32.Min;
33+
34+
const u32Max = u32.Max;
35+
const u32Min = u32.Min;
36+
1737
[entry(frag)]
1838
fn main()
1939
{
@@ -72,6 +92,38 @@ void main()
7292
)", {}, glslEnv);
7393

7494
ExpectNZSL(*shaderModule, R"(
95+
const f32Eps: f32 = f32.Epsilon;
96+
97+
const f32Max: f32 = f32.Max;
98+
99+
const f32Min: f32 = f32.Min;
100+
101+
const f32MinPos: f32 = f32.MinPositive;
102+
103+
const f32Inf: f32 = f32.Infinity;
104+
105+
const f32NaN: f32 = f32.NaN;
106+
107+
const f64Eps: f64 = f64.Epsilon;
108+
109+
const f64Max: f64 = f64.Max;
110+
111+
const f64Min: f64 = f64.Min;
112+
113+
const f64MinPos: f64 = f64.MinPositive;
114+
115+
const f64Inf: f64 = f64.Infinity;
116+
117+
const f64NaN: f64 = f64.NaN;
118+
119+
const i32Max: i32 = i32.Max;
120+
121+
const i32Min: i32 = i32.Min;
122+
123+
const u32Max: u32 = u32.Max;
124+
125+
const u32Min: u32 = u32.Min;
126+
75127
[entry(frag)]
76128
fn main()
77129
{

0 commit comments

Comments
 (0)