@@ -2128,6 +2128,161 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
21282128 return loopOp;
21292129}
21302130
2131+ static mlir::omp::CanonicalLoopOp
2132+ genCanonicalLoopOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2133+ semantics::SemanticsContext &semaCtx,
2134+ lower::pft::Evaluation &eval, mlir::Location loc,
2135+ const ConstructQueue &queue,
2136+ ConstructQueue::const_iterator item,
2137+ llvm::ArrayRef<const semantics::Symbol *> ivs,
2138+ llvm::omp::Directive directive, DataSharingProcessor &dsp) {
2139+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2140+
2141+ assert (ivs.size () == 1 && " Nested loops not yet implemented" );
2142+ const semantics::Symbol *iv = ivs[0 ];
2143+
2144+ auto &nestedEval = eval.getFirstNestedEvaluation ();
2145+ if (nestedEval.getIf <parser::DoConstruct>()->IsDoConcurrent ()) {
2146+ TODO (loc, " Do Concurrent in unroll construct" );
2147+ }
2148+
2149+ // Get the loop bounds (and increment)
2150+ auto &doLoopEval = nestedEval.getFirstNestedEvaluation ();
2151+ auto *doStmt = doLoopEval.getIf <parser::NonLabelDoStmt>();
2152+ assert (doStmt && " Expected do loop to be in the nested evaluation" );
2153+ auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t );
2154+ assert (loopControl.has_value ());
2155+ auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u );
2156+ assert (bounds && " Expected bounds for canonical loop" );
2157+ lower::StatementContext stmtCtx;
2158+ mlir::Value loopLBVar = fir::getBase (
2159+ converter.genExprValue (*semantics::GetExpr (bounds->lower ), stmtCtx));
2160+ mlir::Value loopUBVar = fir::getBase (
2161+ converter.genExprValue (*semantics::GetExpr (bounds->upper ), stmtCtx));
2162+ mlir::Value loopStepVar = [&]() {
2163+ if (bounds->step ) {
2164+ return fir::getBase (
2165+ converter.genExprValue (*semantics::GetExpr (bounds->step ), stmtCtx));
2166+ } else {
2167+ // If `step` is not present, assume it is `1`.
2168+ return firOpBuilder.createIntegerConstant (loc, firOpBuilder.getI32Type (),
2169+ 1 );
2170+ }
2171+ }();
2172+
2173+ // Get the integer kind for the loop variable and cast the loop bounds
2174+ size_t loopVarTypeSize = bounds->name .thing .symbol ->GetUltimate ().size ();
2175+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
2176+ loopLBVar = firOpBuilder.createConvert (loc, loopVarType, loopLBVar);
2177+ loopUBVar = firOpBuilder.createConvert (loc, loopVarType, loopUBVar);
2178+ loopStepVar = firOpBuilder.createConvert (loc, loopVarType, loopStepVar);
2179+
2180+ // Start lowering
2181+ mlir::Value zero = firOpBuilder.createIntegerConstant (loc, loopVarType, 0 );
2182+ mlir::Value one = firOpBuilder.createIntegerConstant (loc, loopVarType, 1 );
2183+ mlir::Value isDownwards = firOpBuilder.create <mlir::arith::CmpIOp>(
2184+ loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
2185+
2186+ // Ensure we are counting upwards. If not, negate step and swap lb and ub.
2187+ mlir::Value negStep =
2188+ firOpBuilder.create <mlir::arith::SubIOp>(loc, zero, loopStepVar);
2189+ mlir::Value incr = firOpBuilder.create <mlir::arith::SelectOp>(
2190+ loc, isDownwards, negStep, loopStepVar);
2191+ mlir::Value lb = firOpBuilder.create <mlir::arith::SelectOp>(
2192+ loc, isDownwards, loopUBVar, loopLBVar);
2193+ mlir::Value ub = firOpBuilder.create <mlir::arith::SelectOp>(
2194+ loc, isDownwards, loopLBVar, loopUBVar);
2195+
2196+ // Compute the trip count assuming lb <= ub. This guarantees that the result
2197+ // is non-negative and we can use unsigned arithmetic.
2198+ mlir::Value span = firOpBuilder.create <mlir::arith::SubIOp>(
2199+ loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
2200+ mlir::Value tcMinusOne =
2201+ firOpBuilder.create <mlir::arith::DivUIOp>(loc, span, incr);
2202+ mlir::Value tcIfLooping = firOpBuilder.create <mlir::arith::AddIOp>(
2203+ loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
2204+
2205+ // Fall back to 0 if lb > ub
2206+ mlir::Value isZeroTC = firOpBuilder.create <mlir::arith::CmpIOp>(
2207+ loc, mlir::arith::CmpIPredicate::slt, ub, lb);
2208+ mlir::Value tripcount = firOpBuilder.create <mlir::arith::SelectOp>(
2209+ loc, isZeroTC, zero, tcIfLooping);
2210+
2211+ // Create the CLI handle.
2212+ auto newcli = firOpBuilder.create <mlir::omp::NewCliOp>(loc);
2213+ mlir::Value cli = newcli.getResult ();
2214+
2215+ auto ivCallback = [&](mlir::Operation *op)
2216+ -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
2217+ mlir::Region ®ion = op->getRegion (0 );
2218+
2219+ // Create the op's region skeleton (BB taking the iv as argument)
2220+ firOpBuilder.createBlock (®ion, {}, {loopVarType}, {loc});
2221+
2222+ // Compute the value of the loop variable from the logical iteration number.
2223+ mlir::Value natIterNum = fir::getBase (region.front ().getArgument (0 ));
2224+ mlir::Value scaled =
2225+ firOpBuilder.create <mlir::arith::MulIOp>(loc, natIterNum, loopStepVar);
2226+ mlir::Value userVal =
2227+ firOpBuilder.create <mlir::arith::AddIOp>(loc, loopLBVar, scaled);
2228+
2229+ // The argument is not currently in memory, so make a temporary for the
2230+ // argument, and store it there, then bind that location to the argument.
2231+ mlir::Operation *storeOp =
2232+ createAndSetPrivatizedLoopVar (converter, loc, userVal, iv);
2233+
2234+ firOpBuilder.setInsertionPointAfter (storeOp);
2235+ return {iv};
2236+ };
2237+
2238+ // Create the omp.canonical_loop operation
2239+ auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
2240+ OpWithBodyGenInfo (converter, symTable, semaCtx, loc, nestedEval,
2241+ directive)
2242+ .setClauses (&item->clauses )
2243+ .setDataSharingProcessor (&dsp)
2244+ .setGenRegionEntryCb (ivCallback),
2245+ queue, item, tripcount, cli);
2246+
2247+ firOpBuilder.setInsertionPointAfter (canonLoop);
2248+ return canonLoop;
2249+ }
2250+
2251+ static void genUnrollOp (Fortran::lower::AbstractConverter &converter,
2252+ Fortran::lower::SymMap &symTable,
2253+ lower::StatementContext &stmtCtx,
2254+ Fortran::semantics::SemanticsContext &semaCtx,
2255+ Fortran::lower::pft::Evaluation &eval,
2256+ mlir::Location loc, const ConstructQueue &queue,
2257+ ConstructQueue::const_iterator item) {
2258+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2259+
2260+ mlir::omp::LoopRelatedClauseOps loopInfo;
2261+ llvm::SmallVector<const semantics::Symbol *> iv;
2262+ collectLoopRelatedInfo (converter, loc, eval, item->clauses , loopInfo, iv);
2263+
2264+ // Clauses for unrolling not yet implemnted
2265+ ClauseProcessor cp (converter, semaCtx, item->clauses );
2266+ cp.processTODO <clause::Partial, clause::Full>(
2267+ loc, llvm::omp::Directive::OMPD_unroll);
2268+
2269+ // Even though unroll does not support data-sharing clauses, but this is
2270+ // required to fill the symbol table.
2271+ DataSharingProcessor dsp (converter, semaCtx, item->clauses , eval,
2272+ /* shouldCollectPreDeterminedSymbols=*/ true ,
2273+ /* useDelayedPrivatization=*/ false , symTable);
2274+ dsp.processStep1 ();
2275+
2276+ // Emit the associated loop
2277+ auto canonLoop =
2278+ genCanonicalLoopOp (converter, symTable, semaCtx, eval, loc, queue, item,
2279+ iv, llvm::omp::Directive::OMPD_unroll, dsp);
2280+
2281+ // Apply unrolling to it
2282+ auto cli = canonLoop.getCli ();
2283+ firOpBuilder.create <mlir::omp::UnrollHeuristicOp>(loc, cli);
2284+ }
2285+
21312286static mlir::omp::MaskedOp
21322287genMaskedOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
21332288 lower::StatementContext &stmtCtx,
@@ -3516,12 +3671,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
35163671 newOp = genTeamsOp (converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
35173672 item);
35183673 break ;
3519- case llvm::omp::Directive::OMPD_tile:
3520- case llvm::omp::Directive::OMPD_unroll: {
3521- unsigned version = semaCtx.langOptions ().OpenMPVersion ;
3522- TODO (loc, " Unhandled loop directive (" +
3523- llvm::omp::getOpenMPDirectiveName (dir, version) + " )" );
3524- }
3674+ case llvm::omp::Directive::OMPD_unroll:
3675+ genUnrollOp (converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
3676+ break ;
35253677 // case llvm::omp::Directive::OMPD_workdistribute:
35263678 case llvm::omp::Directive::OMPD_workshare:
35273679 newOp = genWorkshareOp (converter, symTable, stmtCtx, semaCtx, eval, loc,
0 commit comments