|
63 | 63 | #include "flang/Semantics/tools.h" |
64 | 64 | #include "flang/Support/Version.h" |
65 | 65 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 66 | +#include "mlir/IR/BuiltinAttributes.h" |
66 | 67 | #include "mlir/IR/Matchers.h" |
67 | 68 | #include "mlir/IR/PatternMatch.h" |
68 | 69 | #include "mlir/Parser/Parser.h" |
@@ -2170,32 +2171,54 @@ class FirConverter : public Fortran::lower::AbstractConverter { |
2170 | 2171 | return builder->createIntegerConstant(loc, controlType, 1); // step |
2171 | 2172 | } |
2172 | 2173 |
|
| 2174 | + // For unroll directives without a value, force full unrolling. |
| 2175 | + // For unroll directives with a value, if the value is greater than 1, |
| 2176 | + // force unrolling with the given factor. Otherwise, disable unrolling. |
| 2177 | + mlir::LLVM::LoopUnrollAttr |
| 2178 | + genLoopUnrollAttr(std::optional<std::uint64_t> directiveArg) { |
| 2179 | + mlir::BoolAttr falseAttr = |
| 2180 | + mlir::BoolAttr::get(builder->getContext(), false); |
| 2181 | + mlir::BoolAttr trueAttr = mlir::BoolAttr::get(builder->getContext(), true); |
| 2182 | + mlir::IntegerAttr countAttr; |
| 2183 | + mlir::BoolAttr fullUnrollAttr; |
| 2184 | + bool shouldUnroll = true; |
| 2185 | + if (directiveArg.has_value()) { |
| 2186 | + auto unrollingFactor = directiveArg.value(); |
| 2187 | + if (unrollingFactor == 0 || unrollingFactor == 1) { |
| 2188 | + shouldUnroll = false; |
| 2189 | + } else { |
| 2190 | + countAttr = |
| 2191 | + builder->getIntegerAttr(builder->getI64Type(), unrollingFactor); |
| 2192 | + } |
| 2193 | + } else { |
| 2194 | + fullUnrollAttr = trueAttr; |
| 2195 | + } |
| 2196 | + |
| 2197 | + mlir::BoolAttr disableAttr = shouldUnroll ? falseAttr : trueAttr; |
| 2198 | + return mlir::LLVM::LoopUnrollAttr::get( |
| 2199 | + builder->getContext(), /*disable=*/disableAttr, /*count=*/countAttr, {}, |
| 2200 | + /*full=*/fullUnrollAttr, {}, {}, {}); |
| 2201 | + } |
| 2202 | + |
2173 | 2203 | void addLoopAnnotationAttr( |
2174 | 2204 | IncrementLoopInfo &info, |
2175 | 2205 | llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) { |
2176 | | - mlir::BoolAttr f = mlir::BoolAttr::get(builder->getContext(), false); |
2177 | | - mlir::BoolAttr t = mlir::BoolAttr::get(builder->getContext(), true); |
2178 | 2206 | mlir::LLVM::LoopVectorizeAttr va; |
2179 | 2207 | mlir::LLVM::LoopUnrollAttr ua; |
2180 | 2208 | bool has_attrs = false; |
2181 | 2209 | for (const auto *dir : dirs) { |
2182 | 2210 | Fortran::common::visit( |
2183 | 2211 | Fortran::common::visitors{ |
2184 | 2212 | [&](const Fortran::parser::CompilerDirective::VectorAlways &) { |
| 2213 | + mlir::BoolAttr falseAttr = |
| 2214 | + mlir::BoolAttr::get(builder->getContext(), false); |
2185 | 2215 | va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(), |
2186 | | - /*disable=*/f, {}, {}, |
2187 | | - {}, {}, {}, {}); |
| 2216 | + /*disable=*/falseAttr, |
| 2217 | + {}, {}, {}, {}, {}, {}); |
2188 | 2218 | has_attrs = true; |
2189 | 2219 | }, |
2190 | 2220 | [&](const Fortran::parser::CompilerDirective::Unroll &u) { |
2191 | | - mlir::IntegerAttr countAttr; |
2192 | | - if (u.v.has_value()) { |
2193 | | - countAttr = builder->getIntegerAttr(builder->getI64Type(), |
2194 | | - u.v.value()); |
2195 | | - } |
2196 | | - ua = mlir::LLVM::LoopUnrollAttr::get( |
2197 | | - builder->getContext(), /*disable=*/f, /*count*/ countAttr, |
2198 | | - {}, /*full*/ u.v.has_value() ? f : t, {}, {}, {}); |
| 2221 | + ua = genLoopUnrollAttr(u.v); |
2199 | 2222 | has_attrs = true; |
2200 | 2223 | }, |
2201 | 2224 | [&](const auto &) {}}, |
|
0 commit comments