|
| 1 | +// Copyright (C) 2025 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) |
| 2 | +// This file is part of the "Nazara Shading Language" project |
| 3 | +// For conditions of distribution and use, see copyright notice in Config.hpp |
| 4 | + |
| 5 | +#include <NZSL/Ast/Transformations/LoopUnrollTransformer.hpp> |
| 6 | +#include <NZSL/Ast/Cloner.hpp> |
| 7 | +#include <NZSL/Ast/Utils.hpp> |
| 8 | +#include <NZSL/Lang/Errors.hpp> |
| 9 | +#include <NZSL/Ast/Transformations/TransformerContext.hpp> |
| 10 | +#include <NZSL/Ast/IndexRemapperVisitor.hpp> |
| 11 | + |
| 12 | +namespace nzsl::Ast |
| 13 | +{ |
| 14 | + bool LoopUnrollTransformer::Transform(Module& module, TransformerContext& context, const Options& options, std::string* error) |
| 15 | + { |
| 16 | + assert(m_variableMappings.empty()); |
| 17 | + m_options = &options; |
| 18 | + |
| 19 | + m_currentModuleId = 0; |
| 20 | + for (auto& importedModule : module.importedModules) |
| 21 | + { |
| 22 | + if (!TransformModule(*importedModule.module, context, error)) |
| 23 | + return false; |
| 24 | + |
| 25 | + m_currentModuleId++; |
| 26 | + } |
| 27 | + |
| 28 | + m_currentModuleId = Nz::MaxValue(); |
| 29 | + return TransformModule(module, context, error); |
| 30 | + } |
| 31 | + |
| 32 | + auto LoopUnrollTransformer::Transform(IdentifierValueExpression&& expression) -> ExpressionTransformation |
| 33 | + { |
| 34 | + if (expression.identifierType != IdentifierType::Variable) |
| 35 | + return VisitChildren{}; |
| 36 | + |
| 37 | + for (const VariableRemapping& remapping : m_variableMappings) |
| 38 | + { |
| 39 | + if (expression.identifierIndex == remapping.sourceVariableIndex) |
| 40 | + { |
| 41 | + expression.identifierType = remapping.targetIdentifierType; |
| 42 | + expression.identifierIndex = remapping.targetIdentifierIndex; |
| 43 | + break; |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + return VisitChildren{}; |
| 48 | + } |
| 49 | + |
| 50 | + auto LoopUnrollTransformer::Transform(ForEachStatement&& forEachStatement) -> StatementTransformation |
| 51 | + { |
| 52 | + if (!m_options->unrollForEachLoops) |
| 53 | + return VisitChildren{}; |
| 54 | + |
| 55 | + if (!forEachStatement.unroll.HasValue() || forEachStatement.unroll.GetResultingValue() != LoopUnroll::Always) |
| 56 | + return VisitChildren{}; |
| 57 | + |
| 58 | + const ExpressionType* exprType = GetExpressionType(*forEachStatement.expression); |
| 59 | + if (!exprType) |
| 60 | + return VisitChildren{}; |
| 61 | + |
| 62 | + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); |
| 63 | + |
| 64 | + PushScope(); |
| 65 | + ClearFlags(TransformerFlag::IgnoreExpressions); |
| 66 | + |
| 67 | + // Repeat code |
| 68 | + auto multi = std::make_unique<MultiStatement>(); |
| 69 | + multi->sourceLocation = forEachStatement.sourceLocation; |
| 70 | + |
| 71 | + if (IsArrayType(resolvedExprType)) |
| 72 | + { |
| 73 | + const ArrayType& arrayType = std::get<ArrayType>(resolvedExprType); |
| 74 | + |
| 75 | + ExpressionType constantType = UnwrapExternalType(arrayType.InnerType()); |
| 76 | + |
| 77 | + std::size_t mappingIndex = Nz::MaxValue(); |
| 78 | + if (forEachStatement.varIndex) |
| 79 | + { |
| 80 | + mappingIndex = m_variableMappings.size(); |
| 81 | + m_variableMappings.emplace_back(VariableRemapping{ |
| 82 | + IdentifierType::Variable, |
| 83 | + *forEachStatement.varIndex, |
| 84 | + 0 |
| 85 | + }); |
| 86 | + } |
| 87 | + |
| 88 | + for (std::uint32_t counter = 0; counter < arrayType.length; ++counter) |
| 89 | + { |
| 90 | + PushScope(); |
| 91 | + |
| 92 | + auto innerMulti = std::make_unique<MultiStatement>(); |
| 93 | + innerMulti->sourceLocation = forEachStatement.sourceLocation; |
| 94 | + |
| 95 | + auto constant = ShaderBuilder::ConstantValue(counter, forEachStatement.sourceLocation); |
| 96 | + |
| 97 | + ExpressionPtr accessIndex = ShaderBuilder::AccessIndex(Ast::Clone(*forEachStatement.expression), std::move(constant)); |
| 98 | + accessIndex->cachedExpressionType = constantType; |
| 99 | + accessIndex->sourceLocation = forEachStatement.sourceLocation; |
| 100 | + |
| 101 | + DeclareVariableStatementPtr elementVariable = ShaderBuilder::DeclareVariable(forEachStatement.varName, std::move(accessIndex)); |
| 102 | + elementVariable->sourceLocation = forEachStatement.sourceLocation; |
| 103 | + elementVariable->varIndex = m_context->variables.Register(TransformerContext::VariableData{ constantType }, {}, forEachStatement.sourceLocation); |
| 104 | + elementVariable->varType = constantType; |
| 105 | + |
| 106 | + if (mappingIndex != Nz::MaxValue<std::size_t>()) |
| 107 | + m_variableMappings[mappingIndex].targetIdentifierIndex = *elementVariable->varIndex; |
| 108 | + |
| 109 | + innerMulti->statements.emplace_back(std::move(elementVariable)); |
| 110 | + |
| 111 | + // Remap indices (as unrolling the loop will reuse them) |
| 112 | + IndexRemapperVisitor::Options indexCallbacks; |
| 113 | + indexCallbacks.indexGenerator = [this](IdentifierType identifierType, std::size_t /*previousIndex*/) |
| 114 | + { |
| 115 | + switch (identifierType) |
| 116 | + { |
| 117 | + case IdentifierType::Alias: return m_context->aliases.RegisterNewIndex(true); |
| 118 | + case IdentifierType::Constant: return m_context->constants.RegisterNewIndex(true); |
| 119 | + case IdentifierType::Function: return m_context->functions.RegisterNewIndex(true); |
| 120 | + case IdentifierType::Struct: return m_context->structs.RegisterNewIndex(true); |
| 121 | + case IdentifierType::Variable: return m_context->variables.RegisterNewIndex(true); |
| 122 | + |
| 123 | + default: |
| 124 | + throw std::runtime_error("unexpected identifier type"); |
| 125 | + } |
| 126 | + }; |
| 127 | + |
| 128 | + StatementPtr bodyStatement = Ast::Clone(*forEachStatement.statement); |
| 129 | + RemapIndices(bodyStatement, indexCallbacks); |
| 130 | + HandleStatement(bodyStatement); |
| 131 | + |
| 132 | + innerMulti->statements.emplace_back(Unscope(std::move(bodyStatement))); |
| 133 | + |
| 134 | + multi->statements.emplace_back(ShaderBuilder::Scoped(std::move(innerMulti))); |
| 135 | + |
| 136 | + PopScope(); |
| 137 | + } |
| 138 | + |
| 139 | + if (mappingIndex != Nz::MaxValue<std::size_t>()) |
| 140 | + { |
| 141 | + assert(m_variableMappings.size() == mappingIndex + 1); |
| 142 | + m_variableMappings.pop_back(); |
| 143 | + } |
| 144 | + } |
| 145 | + |
| 146 | + SetFlags(TransformerFlag::IgnoreExpressions); |
| 147 | + PopScope(); |
| 148 | + |
| 149 | + return ReplaceStatement{ std::move(multi) }; |
| 150 | + } |
| 151 | + |
| 152 | + auto LoopUnrollTransformer::Transform(ForStatement&& forStatement) -> StatementTransformation |
| 153 | + { |
| 154 | + if (!m_options->unrollForLoops) |
| 155 | + return VisitChildren{}; |
| 156 | + |
| 157 | + if (!forStatement.unroll.HasValue() || forStatement.unroll.GetResultingValue() != LoopUnroll::Always) |
| 158 | + return VisitChildren{}; |
| 159 | + |
| 160 | + std::optional<ConstantValue> fromValue = ComputeConstantValue(forStatement.fromExpr); |
| 161 | + std::optional<ConstantValue> toValue = ComputeConstantValue(forStatement.toExpr); |
| 162 | + if (!fromValue.has_value() || !toValue.has_value()) |
| 163 | + return VisitChildren{}; |
| 164 | + |
| 165 | + std::optional<ConstantValue> stepValue; |
| 166 | + if (forStatement.stepExpr) |
| 167 | + { |
| 168 | + stepValue = ComputeConstantValue(forStatement.stepExpr); |
| 169 | + if (!stepValue.has_value()) |
| 170 | + return VisitChildren{}; |
| 171 | + } |
| 172 | + |
| 173 | + auto multi = std::make_unique<MultiStatement>(); |
| 174 | + multi->sourceLocation = forStatement.sourceLocation; |
| 175 | + |
| 176 | + auto Unroll = [&](auto dummy) |
| 177 | + { |
| 178 | + using T = std::decay_t<decltype(dummy)>; |
| 179 | + |
| 180 | + auto GetValue = [](const ConstantValue& constantValue, const SourceLocation& sourceLocation) -> T |
| 181 | + { |
| 182 | + if (const IntLiteral* literal = std::get_if<IntLiteral>(&constantValue)) |
| 183 | + { |
| 184 | + if constexpr (std::is_same_v<T, std::int32_t>) |
| 185 | + return LiteralToInt32(*literal, sourceLocation); |
| 186 | + else if constexpr (std::is_same_v<T, std::uint32_t>) |
| 187 | + return LiteralToUInt32(*literal, sourceLocation); |
| 188 | + } |
| 189 | + |
| 190 | + return std::get<T>(constantValue); |
| 191 | + }; |
| 192 | + |
| 193 | + T counter = GetValue(*fromValue, forStatement.fromExpr->sourceLocation); |
| 194 | + T to = GetValue(*toValue, forStatement.toExpr->sourceLocation); |
| 195 | + T step = (forStatement.stepExpr) ? GetValue(*stepValue, forStatement.stepExpr->sourceLocation) : T{ 1 }; |
| 196 | + |
| 197 | + std::size_t mappingIndex = Nz::MaxValue(); |
| 198 | + if (forStatement.varIndex) |
| 199 | + { |
| 200 | + mappingIndex = m_variableMappings.size(); |
| 201 | + m_variableMappings.emplace_back(VariableRemapping{ |
| 202 | + IdentifierType::Constant, |
| 203 | + *forStatement.varIndex, |
| 204 | + 0 |
| 205 | + }); |
| 206 | + } |
| 207 | + |
| 208 | + ClearFlags(TransformerFlag::IgnoreExpressions); |
| 209 | + |
| 210 | + for (; counter < to; counter += step) |
| 211 | + { |
| 212 | + PushScope(); |
| 213 | + |
| 214 | + auto innerMulti = std::make_unique<MultiStatement>(); |
| 215 | + innerMulti->sourceLocation = forStatement.sourceLocation; |
| 216 | + |
| 217 | + auto constant = ShaderBuilder::ConstantValue(counter, forStatement.sourceLocation); |
| 218 | + |
| 219 | + ExpressionValue<ExpressionType> constantType; |
| 220 | + if (constant->cachedExpressionType) |
| 221 | + constantType = *constant->cachedExpressionType; |
| 222 | + |
| 223 | + DeclareConstStatementPtr constDecl = ShaderBuilder::DeclareConst(forStatement.varName, std::move(constantType), std::move(constant)); |
| 224 | + constDecl->constIndex = m_context->constants.Register(TransformerContext::ConstantData{ m_currentModuleId, counter }, {}, forStatement.sourceLocation); |
| 225 | + constDecl->sourceLocation = forStatement.sourceLocation; |
| 226 | + |
| 227 | + if (mappingIndex != Nz::MaxValue<std::size_t>()) |
| 228 | + m_variableMappings[mappingIndex].targetIdentifierIndex = *constDecl->constIndex; |
| 229 | + |
| 230 | + innerMulti->statements.emplace_back(std::move(constDecl)); |
| 231 | + |
| 232 | + // Remap indices (as unrolling the loop will reuse them) |
| 233 | + IndexRemapperVisitor::Options indexCallbacks; |
| 234 | + indexCallbacks.indexGenerator = [this](IdentifierType identifierType, std::size_t /*previousIndex*/) |
| 235 | + { |
| 236 | + switch (identifierType) |
| 237 | + { |
| 238 | + case IdentifierType::Alias: return m_context->aliases.RegisterNewIndex(true); |
| 239 | + case IdentifierType::Constant: return m_context->constants.RegisterNewIndex(true); |
| 240 | + case IdentifierType::Function: return m_context->functions.RegisterNewIndex(true); |
| 241 | + case IdentifierType::Struct: return m_context->structs.RegisterNewIndex(true); |
| 242 | + case IdentifierType::Variable: return m_context->variables.RegisterNewIndex(true); |
| 243 | + |
| 244 | + default: |
| 245 | + throw std::runtime_error("unexpected identifier type"); |
| 246 | + } |
| 247 | + }; |
| 248 | + |
| 249 | + StatementPtr bodyStatement = Ast::Clone(*forStatement.statement); |
| 250 | + RemapIndices(bodyStatement, indexCallbacks); |
| 251 | + HandleStatement(bodyStatement); |
| 252 | + |
| 253 | + innerMulti->statements.emplace_back(Unscope(std::move(bodyStatement))); |
| 254 | + |
| 255 | + multi->statements.emplace_back(ShaderBuilder::Scoped(std::move(innerMulti))); |
| 256 | + |
| 257 | + PopScope(); |
| 258 | + } |
| 259 | + |
| 260 | + SetFlags(TransformerFlag::IgnoreExpressions); |
| 261 | + |
| 262 | + if (mappingIndex != Nz::MaxValue<std::size_t>()) |
| 263 | + { |
| 264 | + assert(m_variableMappings.size() == mappingIndex + 1); |
| 265 | + m_variableMappings.pop_back(); |
| 266 | + } |
| 267 | + }; |
| 268 | + |
| 269 | + ExpressionType fromExprType = GetConstantType(*fromValue); |
| 270 | + if (!IsPrimitiveType(fromExprType)) |
| 271 | + throw CompilerForFromTypeExpectIntegerTypeError{ forStatement.fromExpr->sourceLocation, ToString(fromExprType, forStatement.fromExpr->sourceLocation) }; |
| 272 | + |
| 273 | + PrimitiveType counterType = std::get<PrimitiveType>(fromExprType); |
| 274 | + if (counterType != PrimitiveType::Int32 && counterType != PrimitiveType::UInt32 && counterType != PrimitiveType::IntLiteral) |
| 275 | + throw CompilerForFromTypeExpectIntegerTypeError{ forStatement.fromExpr->sourceLocation, ToString(fromExprType, forStatement.fromExpr->sourceLocation) }; |
| 276 | + |
| 277 | + if (counterType == PrimitiveType::IntLiteral) |
| 278 | + { |
| 279 | + // Fallback to "to" type |
| 280 | + ExpressionType toExprType = GetConstantType(*toValue); |
| 281 | + if (!IsPrimitiveType(toExprType)) |
| 282 | + throw CompilerForToUnmatchingTypeError{ forStatement.toExpr->sourceLocation, ToString(fromExprType, forStatement.fromExpr->sourceLocation), ToString(toExprType, forStatement.toExpr->sourceLocation) }; |
| 283 | + |
| 284 | + PrimitiveType toCounterType = std::get<PrimitiveType>(toExprType); |
| 285 | + if (toCounterType != PrimitiveType::Int32 && toCounterType != PrimitiveType::UInt32 && toCounterType != PrimitiveType::IntLiteral) |
| 286 | + throw CompilerForToUnmatchingTypeError{ forStatement.toExpr->sourceLocation, ToString(fromExprType, forStatement.fromExpr->sourceLocation), ToString(toExprType, forStatement.toExpr->sourceLocation) }; |
| 287 | + |
| 288 | + counterType = toCounterType; |
| 289 | + } |
| 290 | + |
| 291 | + if (counterType == PrimitiveType::IntLiteral) |
| 292 | + counterType = PrimitiveType::Int32; |
| 293 | + |
| 294 | + switch (counterType) |
| 295 | + { |
| 296 | + case PrimitiveType::Int32: |
| 297 | + Unroll(std::int32_t{}); |
| 298 | + break; |
| 299 | + |
| 300 | + case PrimitiveType::UInt32: |
| 301 | + Unroll(std::uint32_t{}); |
| 302 | + break; |
| 303 | + |
| 304 | + default: |
| 305 | + throw AstInternalError{ forStatement.sourceLocation, "unexpected counter type " + ToString(counterType, forStatement.fromExpr->sourceLocation) }; |
| 306 | + } |
| 307 | + |
| 308 | + return ReplaceStatement{ std::move(multi) }; |
| 309 | + } |
| 310 | +} |
0 commit comments