Skip to content

Commit 8037e1d

Browse files
committed
Compiler: Move loop unrolling to the LoopUnrollTransformer
1 parent 7fa799d commit 8037e1d

File tree

9 files changed

+417
-222
lines changed

9 files changed

+417
-222
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (C) 2025 Jérôme "SirLynix" Leclercq (lynix680@gmail.com)
2+
// This file is part of the "Nazara Shading Language" project
3+
// For conditions of distribution and use, see copyright notice in Config.hpp
4+
5+
#pragma once
6+
7+
#ifndef NZSL_AST_TRANSFORMATIONS_LOOPUNROLLTRANSFORMER_HPP
8+
#define NZSL_AST_TRANSFORMATIONS_LOOPUNROLLTRANSFORMER_HPP
9+
10+
#include <NZSL/Ast/Transformations/Transformer.hpp>
11+
#include <NazaraUtils/FixedVector.hpp>
12+
13+
namespace nzsl::Ast
14+
{
15+
class NZSL_API LoopUnrollTransformer final : public Transformer
16+
{
17+
public:
18+
struct Options;
19+
20+
inline LoopUnrollTransformer();
21+
22+
inline bool Transform(Module& module, TransformerContext& context, std::string* error = nullptr);
23+
bool Transform(Module& module, TransformerContext& context, const Options& options, std::string* error = nullptr);
24+
25+
struct Options
26+
{
27+
bool unrollForLoops = true;
28+
bool unrollForEachLoops = true;
29+
};
30+
31+
private:
32+
using Transformer::Transform;
33+
34+
ExpressionTransformation Transform(IdentifierValueExpression&& expression) override;
35+
36+
StatementTransformation Transform(ForEachStatement&& statement) override;
37+
StatementTransformation Transform(ForStatement&& statement) override;
38+
39+
struct VariableRemapping
40+
{
41+
IdentifierType targetIdentifierType;
42+
std::size_t sourceVariableIndex;
43+
std::size_t targetIdentifierIndex;
44+
};
45+
46+
std::size_t m_currentModuleId;
47+
Nz::HybridVector<VariableRemapping, 4> m_variableMappings;
48+
const Options* m_options;
49+
};
50+
}
51+
52+
#include <NZSL/Ast/Transformations/LoopUnrollTransformer.inl>
53+
54+
#endif // NZSL_AST_TRANSFORMATIONS_LOOPUNROLLTRANSFORMER_HPP
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright (C) 2025 Jérôme "SirLynix" Leclercq (lynix680@gmail.com)
2+
// This file is part of the "Nazara Shading Language" project
3+
// For conditions of distribution and use, see copyright notice in Config.hpp
4+
5+
namespace nzsl::Ast
6+
{
7+
inline LoopUnrollTransformer::LoopUnrollTransformer() :
8+
Transformer(TransformerFlag::IgnoreExpressions)
9+
{
10+
}
11+
12+
inline bool LoopUnrollTransformer::Transform(Module& module, TransformerContext& context, std::string* error)
13+
{
14+
return Transform(module, context, {}, error);
15+
}
16+
}

include/NZSL/Ast/Transformations/ResolveTransformer.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ namespace nzsl::Ast
3434
// TODO: Turn all of theses into separate passes
3535
std::shared_ptr<ModuleResolver> moduleResolver;
3636
bool removeAliases = false;
37-
bool unrollForLoops = true;
38-
bool unrollForEachLoops = true;
3937
};
4038

4139
private:
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
// Copyright (C) 2025 Jérôme "SirLynix" Leclercq (lynix680@gmail.com)
2+
// This file is part of the "Nazara Shading Language" project
3+
// For conditions of distribution and use, see copyright notice in Config.hpp
4+
5+
#include <NZSL/Ast/Transformations/LoopUnrollTransformer.hpp>
6+
#include <NZSL/Ast/Cloner.hpp>
7+
#include <NZSL/Ast/Utils.hpp>
8+
#include <NZSL/Lang/Errors.hpp>
9+
#include <NZSL/Ast/Transformations/TransformerContext.hpp>
10+
#include <NZSL/Ast/IndexRemapperVisitor.hpp>
11+
12+
namespace nzsl::Ast
13+
{
14+
bool LoopUnrollTransformer::Transform(Module& module, TransformerContext& context, const Options& options, std::string* error)
15+
{
16+
assert(m_variableMappings.empty());
17+
m_options = &options;
18+
19+
m_currentModuleId = 0;
20+
for (auto& importedModule : module.importedModules)
21+
{
22+
if (!TransformModule(*importedModule.module, context, error))
23+
return false;
24+
25+
m_currentModuleId++;
26+
}
27+
28+
m_currentModuleId = Nz::MaxValue();
29+
return TransformModule(module, context, error);
30+
}
31+
32+
auto LoopUnrollTransformer::Transform(IdentifierValueExpression&& expression) -> ExpressionTransformation
33+
{
34+
if (expression.identifierType != IdentifierType::Variable)
35+
return VisitChildren{};
36+
37+
for (const VariableRemapping& remapping : m_variableMappings)
38+
{
39+
if (expression.identifierIndex == remapping.sourceVariableIndex)
40+
{
41+
expression.identifierType = remapping.targetIdentifierType;
42+
expression.identifierIndex = remapping.targetIdentifierIndex;
43+
break;
44+
}
45+
}
46+
47+
return VisitChildren{};
48+
}
49+
50+
auto LoopUnrollTransformer::Transform(ForEachStatement&& forEachStatement) -> StatementTransformation
51+
{
52+
if (!m_options->unrollForEachLoops)
53+
return VisitChildren{};
54+
55+
if (!forEachStatement.unroll.HasValue() || forEachStatement.unroll.GetResultingValue() != LoopUnroll::Always)
56+
return VisitChildren{};
57+
58+
const ExpressionType* exprType = GetExpressionType(*forEachStatement.expression);
59+
if (!exprType)
60+
return VisitChildren{};
61+
62+
const ExpressionType& resolvedExprType = ResolveAlias(*exprType);
63+
64+
PushScope();
65+
ClearFlags(TransformerFlag::IgnoreExpressions);
66+
67+
// Repeat code
68+
auto multi = std::make_unique<MultiStatement>();
69+
multi->sourceLocation = forEachStatement.sourceLocation;
70+
71+
if (IsArrayType(resolvedExprType))
72+
{
73+
const ArrayType& arrayType = std::get<ArrayType>(resolvedExprType);
74+
75+
ExpressionType constantType = UnwrapExternalType(arrayType.InnerType());
76+
77+
std::size_t mappingIndex = Nz::MaxValue();
78+
if (forEachStatement.varIndex)
79+
{
80+
mappingIndex = m_variableMappings.size();
81+
m_variableMappings.emplace_back(VariableRemapping{
82+
IdentifierType::Variable,
83+
*forEachStatement.varIndex,
84+
0
85+
});
86+
}
87+
88+
for (std::uint32_t counter = 0; counter < arrayType.length; ++counter)
89+
{
90+
PushScope();
91+
92+
auto innerMulti = std::make_unique<MultiStatement>();
93+
innerMulti->sourceLocation = forEachStatement.sourceLocation;
94+
95+
auto constant = ShaderBuilder::ConstantValue(counter, forEachStatement.sourceLocation);
96+
97+
ExpressionPtr accessIndex = ShaderBuilder::AccessIndex(Ast::Clone(*forEachStatement.expression), std::move(constant));
98+
accessIndex->cachedExpressionType = constantType;
99+
accessIndex->sourceLocation = forEachStatement.sourceLocation;
100+
101+
DeclareVariableStatementPtr elementVariable = ShaderBuilder::DeclareVariable(forEachStatement.varName, std::move(accessIndex));
102+
elementVariable->sourceLocation = forEachStatement.sourceLocation;
103+
elementVariable->varIndex = m_context->variables.Register(TransformerContext::VariableData{ constantType }, {}, forEachStatement.sourceLocation);
104+
elementVariable->varType = constantType;
105+
106+
if (mappingIndex != Nz::MaxValue<std::size_t>())
107+
m_variableMappings[mappingIndex].targetIdentifierIndex = *elementVariable->varIndex;
108+
109+
innerMulti->statements.emplace_back(std::move(elementVariable));
110+
111+
// Remap indices (as unrolling the loop will reuse them)
112+
IndexRemapperVisitor::Options indexCallbacks;
113+
indexCallbacks.indexGenerator = [this](IdentifierType identifierType, std::size_t /*previousIndex*/)
114+
{
115+
switch (identifierType)
116+
{
117+
case IdentifierType::Alias: return m_context->aliases.RegisterNewIndex(true);
118+
case IdentifierType::Constant: return m_context->constants.RegisterNewIndex(true);
119+
case IdentifierType::Function: return m_context->functions.RegisterNewIndex(true);
120+
case IdentifierType::Struct: return m_context->structs.RegisterNewIndex(true);
121+
case IdentifierType::Variable: return m_context->variables.RegisterNewIndex(true);
122+
123+
default:
124+
throw std::runtime_error("unexpected identifier type");
125+
}
126+
};
127+
128+
StatementPtr bodyStatement = Ast::Clone(*forEachStatement.statement);
129+
RemapIndices(bodyStatement, indexCallbacks);
130+
HandleStatement(bodyStatement);
131+
132+
innerMulti->statements.emplace_back(Unscope(std::move(bodyStatement)));
133+
134+
multi->statements.emplace_back(ShaderBuilder::Scoped(std::move(innerMulti)));
135+
136+
PopScope();
137+
}
138+
139+
if (mappingIndex != Nz::MaxValue<std::size_t>())
140+
{
141+
assert(m_variableMappings.size() == mappingIndex + 1);
142+
m_variableMappings.pop_back();
143+
}
144+
}
145+
146+
SetFlags(TransformerFlag::IgnoreExpressions);
147+
PopScope();
148+
149+
return ReplaceStatement{ std::move(multi) };
150+
}
151+
152+
auto LoopUnrollTransformer::Transform(ForStatement&& forStatement) -> StatementTransformation
153+
{
154+
if (!m_options->unrollForLoops)
155+
return VisitChildren{};
156+
157+
if (!forStatement.unroll.HasValue() || forStatement.unroll.GetResultingValue() != LoopUnroll::Always)
158+
return VisitChildren{};
159+
160+
std::optional<ConstantValue> fromValue = ComputeConstantValue(forStatement.fromExpr);
161+
std::optional<ConstantValue> toValue = ComputeConstantValue(forStatement.toExpr);
162+
if (!fromValue.has_value() || !toValue.has_value())
163+
return VisitChildren{};
164+
165+
std::optional<ConstantValue> stepValue;
166+
if (forStatement.stepExpr)
167+
{
168+
stepValue = ComputeConstantValue(forStatement.stepExpr);
169+
if (!stepValue.has_value())
170+
return VisitChildren{};
171+
}
172+
173+
auto multi = std::make_unique<MultiStatement>();
174+
multi->sourceLocation = forStatement.sourceLocation;
175+
176+
auto Unroll = [&](auto dummy)
177+
{
178+
using T = std::decay_t<decltype(dummy)>;
179+
180+
auto GetValue = [](const ConstantValue& constantValue, const SourceLocation& sourceLocation) -> T
181+
{
182+
if (const IntLiteral* literal = std::get_if<IntLiteral>(&constantValue))
183+
{
184+
if constexpr (std::is_same_v<T, std::int32_t>)
185+
return LiteralToInt32(*literal, sourceLocation);
186+
else if constexpr (std::is_same_v<T, std::uint32_t>)
187+
return LiteralToUInt32(*literal, sourceLocation);
188+
}
189+
190+
return std::get<T>(constantValue);
191+
};
192+
193+
T counter = GetValue(*fromValue, forStatement.fromExpr->sourceLocation);
194+
T to = GetValue(*toValue, forStatement.toExpr->sourceLocation);
195+
T step = (forStatement.stepExpr) ? GetValue(*stepValue, forStatement.stepExpr->sourceLocation) : T{ 1 };
196+
197+
std::size_t mappingIndex = Nz::MaxValue();
198+
if (forStatement.varIndex)
199+
{
200+
mappingIndex = m_variableMappings.size();
201+
m_variableMappings.emplace_back(VariableRemapping{
202+
IdentifierType::Constant,
203+
*forStatement.varIndex,
204+
0
205+
});
206+
}
207+
208+
ClearFlags(TransformerFlag::IgnoreExpressions);
209+
210+
for (; counter < to; counter += step)
211+
{
212+
PushScope();
213+
214+
auto innerMulti = std::make_unique<MultiStatement>();
215+
innerMulti->sourceLocation = forStatement.sourceLocation;
216+
217+
auto constant = ShaderBuilder::ConstantValue(counter, forStatement.sourceLocation);
218+
219+
ExpressionValue<ExpressionType> constantType;
220+
if (constant->cachedExpressionType)
221+
constantType = *constant->cachedExpressionType;
222+
223+
DeclareConstStatementPtr constDecl = ShaderBuilder::DeclareConst(forStatement.varName, std::move(constantType), std::move(constant));
224+
constDecl->constIndex = m_context->constants.Register(TransformerContext::ConstantData{ m_currentModuleId, counter }, {}, forStatement.sourceLocation);
225+
constDecl->sourceLocation = forStatement.sourceLocation;
226+
227+
if (mappingIndex != Nz::MaxValue<std::size_t>())
228+
m_variableMappings[mappingIndex].targetIdentifierIndex = *constDecl->constIndex;
229+
230+
innerMulti->statements.emplace_back(std::move(constDecl));
231+
232+
// Remap indices (as unrolling the loop will reuse them)
233+
IndexRemapperVisitor::Options indexCallbacks;
234+
indexCallbacks.indexGenerator = [this](IdentifierType identifierType, std::size_t /*previousIndex*/)
235+
{
236+
switch (identifierType)
237+
{
238+
case IdentifierType::Alias: return m_context->aliases.RegisterNewIndex(true);
239+
case IdentifierType::Constant: return m_context->constants.RegisterNewIndex(true);
240+
case IdentifierType::Function: return m_context->functions.RegisterNewIndex(true);
241+
case IdentifierType::Struct: return m_context->structs.RegisterNewIndex(true);
242+
case IdentifierType::Variable: return m_context->variables.RegisterNewIndex(true);
243+
244+
default:
245+
throw std::runtime_error("unexpected identifier type");
246+
}
247+
};
248+
249+
StatementPtr bodyStatement = Ast::Clone(*forStatement.statement);
250+
RemapIndices(bodyStatement, indexCallbacks);
251+
HandleStatement(bodyStatement);
252+
253+
innerMulti->statements.emplace_back(Unscope(std::move(bodyStatement)));
254+
255+
multi->statements.emplace_back(ShaderBuilder::Scoped(std::move(innerMulti)));
256+
257+
PopScope();
258+
}
259+
260+
SetFlags(TransformerFlag::IgnoreExpressions);
261+
262+
if (mappingIndex != Nz::MaxValue<std::size_t>())
263+
{
264+
assert(m_variableMappings.size() == mappingIndex + 1);
265+
m_variableMappings.pop_back();
266+
}
267+
};
268+
269+
ExpressionType fromExprType = GetConstantType(*fromValue);
270+
if (!IsPrimitiveType(fromExprType))
271+
throw CompilerForFromTypeExpectIntegerTypeError{ forStatement.fromExpr->sourceLocation, ToString(fromExprType, forStatement.fromExpr->sourceLocation) };
272+
273+
PrimitiveType counterType = std::get<PrimitiveType>(fromExprType);
274+
if (counterType != PrimitiveType::Int32 && counterType != PrimitiveType::UInt32 && counterType != PrimitiveType::IntLiteral)
275+
throw CompilerForFromTypeExpectIntegerTypeError{ forStatement.fromExpr->sourceLocation, ToString(fromExprType, forStatement.fromExpr->sourceLocation) };
276+
277+
if (counterType == PrimitiveType::IntLiteral)
278+
{
279+
// Fallback to "to" type
280+
ExpressionType toExprType = GetConstantType(*toValue);
281+
if (!IsPrimitiveType(toExprType))
282+
throw CompilerForToUnmatchingTypeError{ forStatement.toExpr->sourceLocation, ToString(fromExprType, forStatement.fromExpr->sourceLocation), ToString(toExprType, forStatement.toExpr->sourceLocation) };
283+
284+
PrimitiveType toCounterType = std::get<PrimitiveType>(toExprType);
285+
if (toCounterType != PrimitiveType::Int32 && toCounterType != PrimitiveType::UInt32 && toCounterType != PrimitiveType::IntLiteral)
286+
throw CompilerForToUnmatchingTypeError{ forStatement.toExpr->sourceLocation, ToString(fromExprType, forStatement.fromExpr->sourceLocation), ToString(toExprType, forStatement.toExpr->sourceLocation) };
287+
288+
counterType = toCounterType;
289+
}
290+
291+
if (counterType == PrimitiveType::IntLiteral)
292+
counterType = PrimitiveType::Int32;
293+
294+
switch (counterType)
295+
{
296+
case PrimitiveType::Int32:
297+
Unroll(std::int32_t{});
298+
break;
299+
300+
case PrimitiveType::UInt32:
301+
Unroll(std::uint32_t{});
302+
break;
303+
304+
default:
305+
throw AstInternalError{ forStatement.sourceLocation, "unexpected counter type " + ToString(counterType, forStatement.fromExpr->sourceLocation) };
306+
}
307+
308+
return ReplaceStatement{ std::move(multi) };
309+
}
310+
}

0 commit comments

Comments
 (0)