@@ -63,6 +63,28 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
6363 lower::pft::Evaluation &eval,
6464 mlir::Location loc);
6565
66+ static llvm::omp::Directive
67+ getOpenMPDirectiveEnum (const parser::OmpLoopDirective &beginStatment) {
68+ return beginStatment.v ;
69+ }
70+
71+ static llvm::omp::Directive getOpenMPDirectiveEnum (
72+ const parser::OmpBeginLoopDirective &beginLoopDirective) {
73+ return getOpenMPDirectiveEnum (
74+ std::get<parser::OmpLoopDirective>(beginLoopDirective.t ));
75+ }
76+
77+ static llvm::omp::Directive
78+ getOpenMPDirectiveEnum (const parser::OpenMPLoopConstruct &ompLoopConstruct) {
79+ return getOpenMPDirectiveEnum (
80+ std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t ));
81+ }
82+
83+ static llvm::omp::Directive getOpenMPDirectiveEnum (
84+ const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) {
85+ return getOpenMPDirectiveEnum (ompLoopConstruct.value ());
86+ }
87+
6688namespace {
6789// / Structure holding information that is needed to pass host-evaluated
6890// / information to later lowering stages.
@@ -2069,6 +2091,163 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
20692091 return loopOp;
20702092}
20712093
2094+ static mlir::omp::CanonicalLoopOp
2095+ genCanonicalLoopOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2096+ semantics::SemanticsContext &semaCtx,
2097+ lower::pft::Evaluation &eval, mlir::Location loc,
2098+ const ConstructQueue &queue,
2099+ ConstructQueue::const_iterator item,
2100+ llvm::ArrayRef<const semantics::Symbol *> ivs,
2101+ llvm::omp::Directive directive, DataSharingProcessor &dsp) {
2102+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2103+
2104+ assert (ivs.size () == 1 && " Nested loops not yet implemented" );
2105+ const semantics::Symbol *iv = ivs[0 ];
2106+
2107+ auto &nestedEval = eval.getFirstNestedEvaluation ();
2108+ if (nestedEval.getIf <parser::DoConstruct>()->IsDoConcurrent ()) {
2109+ // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will
2110+ // need to add special cases for this combination.
2111+ TODO (loc, " DO CONCURRENT as canonical loop not supported" );
2112+ }
2113+
2114+ // Get the loop bounds (and increment)
2115+ auto &doLoopEval = nestedEval.getFirstNestedEvaluation ();
2116+ auto *doStmt = doLoopEval.getIf <parser::NonLabelDoStmt>();
2117+ assert (doStmt && " Expected do loop to be in the nested evaluation" );
2118+ auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t );
2119+ assert (loopControl.has_value ());
2120+ auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u );
2121+ assert (bounds && " Expected bounds for canonical loop" );
2122+ lower::StatementContext stmtCtx;
2123+ mlir::Value loopLBVar = fir::getBase (
2124+ converter.genExprValue (*semantics::GetExpr (bounds->lower ), stmtCtx));
2125+ mlir::Value loopUBVar = fir::getBase (
2126+ converter.genExprValue (*semantics::GetExpr (bounds->upper ), stmtCtx));
2127+ mlir::Value loopStepVar = [&]() {
2128+ if (bounds->step ) {
2129+ return fir::getBase (
2130+ converter.genExprValue (*semantics::GetExpr (bounds->step ), stmtCtx));
2131+ }
2132+
2133+ // If `step` is not present, assume it is `1`.
2134+ return firOpBuilder.createIntegerConstant (loc, firOpBuilder.getI32Type (),
2135+ 1 );
2136+ }();
2137+
2138+ // Get the integer kind for the loop variable and cast the loop bounds
2139+ size_t loopVarTypeSize = bounds->name .thing .symbol ->GetUltimate ().size ();
2140+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
2141+ loopLBVar = firOpBuilder.createConvert (loc, loopVarType, loopLBVar);
2142+ loopUBVar = firOpBuilder.createConvert (loc, loopVarType, loopUBVar);
2143+ loopStepVar = firOpBuilder.createConvert (loc, loopVarType, loopStepVar);
2144+
2145+ // Start lowering
2146+ mlir::Value zero = firOpBuilder.createIntegerConstant (loc, loopVarType, 0 );
2147+ mlir::Value one = firOpBuilder.createIntegerConstant (loc, loopVarType, 1 );
2148+ mlir::Value isDownwards = firOpBuilder.create <mlir::arith::CmpIOp>(
2149+ loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
2150+
2151+ // Ensure we are counting upwards. If not, negate step and swap lb and ub.
2152+ mlir::Value negStep =
2153+ firOpBuilder.create <mlir::arith::SubIOp>(loc, zero, loopStepVar);
2154+ mlir::Value incr = firOpBuilder.create <mlir::arith::SelectOp>(
2155+ loc, isDownwards, negStep, loopStepVar);
2156+ mlir::Value lb = firOpBuilder.create <mlir::arith::SelectOp>(
2157+ loc, isDownwards, loopUBVar, loopLBVar);
2158+ mlir::Value ub = firOpBuilder.create <mlir::arith::SelectOp>(
2159+ loc, isDownwards, loopLBVar, loopUBVar);
2160+
2161+ // Compute the trip count assuming lb <= ub. This guarantees that the result
2162+ // is non-negative and we can use unsigned arithmetic.
2163+ mlir::Value span = firOpBuilder.create <mlir::arith::SubIOp>(
2164+ loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
2165+ mlir::Value tcMinusOne =
2166+ firOpBuilder.create <mlir::arith::DivUIOp>(loc, span, incr);
2167+ mlir::Value tcIfLooping = firOpBuilder.create <mlir::arith::AddIOp>(
2168+ loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
2169+
2170+ // Fall back to 0 if lb > ub
2171+ mlir::Value isZeroTC = firOpBuilder.create <mlir::arith::CmpIOp>(
2172+ loc, mlir::arith::CmpIPredicate::slt, ub, lb);
2173+ mlir::Value tripcount = firOpBuilder.create <mlir::arith::SelectOp>(
2174+ loc, isZeroTC, zero, tcIfLooping);
2175+
2176+ // Create the CLI handle.
2177+ auto newcli = firOpBuilder.create <mlir::omp::NewCliOp>(loc);
2178+ mlir::Value cli = newcli.getResult ();
2179+
2180+ auto ivCallback = [&](mlir::Operation *op)
2181+ -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
2182+ mlir::Region ®ion = op->getRegion (0 );
2183+
2184+ // Create the op's region skeleton (BB taking the iv as argument)
2185+ firOpBuilder.createBlock (®ion, {}, {loopVarType}, {loc});
2186+
2187+ // Compute the value of the loop variable from the logical iteration number.
2188+ mlir::Value natIterNum = fir::getBase (region.front ().getArgument (0 ));
2189+ mlir::Value scaled =
2190+ firOpBuilder.create <mlir::arith::MulIOp>(loc, natIterNum, loopStepVar);
2191+ mlir::Value userVal =
2192+ firOpBuilder.create <mlir::arith::AddIOp>(loc, loopLBVar, scaled);
2193+
2194+ // The argument is not currently in memory, so make a temporary for the
2195+ // argument, and store it there, then bind that location to the argument.
2196+ mlir::Operation *storeOp =
2197+ createAndSetPrivatizedLoopVar (converter, loc, userVal, iv);
2198+
2199+ firOpBuilder.setInsertionPointAfter (storeOp);
2200+ return {iv};
2201+ };
2202+
2203+ // Create the omp.canonical_loop operation
2204+ auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
2205+ OpWithBodyGenInfo (converter, symTable, semaCtx, loc, nestedEval,
2206+ directive)
2207+ .setClauses (&item->clauses )
2208+ .setDataSharingProcessor (&dsp)
2209+ .setGenRegionEntryCb (ivCallback),
2210+ queue, item, tripcount, cli);
2211+
2212+ firOpBuilder.setInsertionPointAfter (canonLoop);
2213+ return canonLoop;
2214+ }
2215+
2216+ static void genUnrollOp (Fortran::lower::AbstractConverter &converter,
2217+ Fortran::lower::SymMap &symTable,
2218+ lower::StatementContext &stmtCtx,
2219+ Fortran::semantics::SemanticsContext &semaCtx,
2220+ Fortran::lower::pft::Evaluation &eval,
2221+ mlir::Location loc, const ConstructQueue &queue,
2222+ ConstructQueue::const_iterator item) {
2223+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2224+
2225+ mlir::omp::LoopRelatedClauseOps loopInfo;
2226+ llvm::SmallVector<const semantics::Symbol *> iv;
2227+ collectLoopRelatedInfo (converter, loc, eval, item->clauses , loopInfo, iv);
2228+
2229+ // Clauses for unrolling not yet implemnted
2230+ ClauseProcessor cp (converter, semaCtx, item->clauses );
2231+ cp.processTODO <clause::Partial, clause::Full>(
2232+ loc, llvm::omp::Directive::OMPD_unroll);
2233+
2234+ // Even though unroll does not support data-sharing clauses, but this is
2235+ // required to fill the symbol table.
2236+ DataSharingProcessor dsp (converter, semaCtx, item->clauses , eval,
2237+ /* shouldCollectPreDeterminedSymbols=*/ true ,
2238+ /* useDelayedPrivatization=*/ false , symTable);
2239+ dsp.processStep1 ();
2240+
2241+ // Emit the associated loop
2242+ auto canonLoop =
2243+ genCanonicalLoopOp (converter, symTable, semaCtx, eval, loc, queue, item,
2244+ iv, llvm::omp::Directive::OMPD_unroll, dsp);
2245+
2246+ // Apply unrolling to it
2247+ auto cli = canonLoop.getCli ();
2248+ firOpBuilder.create <mlir::omp::UnrollHeuristicOp>(loc, cli);
2249+ }
2250+
20722251static mlir::omp::MaskedOp
20732252genMaskedOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
20742253 lower::StatementContext &stmtCtx,
@@ -3249,12 +3428,14 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
32493428 newOp = genTeamsOp (converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
32503429 item);
32513430 break ;
3252- case llvm::omp::Directive::OMPD_tile:
3253- case llvm::omp::Directive::OMPD_unroll: {
3431+ case llvm::omp::Directive::OMPD_tile: {
32543432 unsigned version = semaCtx.langOptions ().OpenMPVersion ;
32553433 TODO (loc, " Unhandled loop directive (" +
32563434 llvm::omp::getOpenMPDirectiveName (dir, version) + " )" );
32573435 }
3436+ case llvm::omp::Directive::OMPD_unroll:
3437+ genUnrollOp (converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
3438+ break ;
32583439 // case llvm::omp::Directive::OMPD_workdistribute:
32593440 case llvm::omp::Directive::OMPD_workshare:
32603441 newOp = genWorkshareOp (converter, symTable, stmtCtx, semaCtx, eval, loc,
@@ -3690,12 +3871,25 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
36903871 if (auto *ompNestedLoopCons{
36913872 std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
36923873 &*optLoopCons)}) {
3693- genOMP (converter, symTable, semaCtx, eval, ompNestedLoopCons->value ());
3874+ llvm::omp::Directive nestedDirective =
3875+ getOpenMPDirectiveEnum (*ompNestedLoopCons);
3876+ switch (nestedDirective) {
3877+ case llvm::omp::Directive::OMPD_tile:
3878+ // Emit the omp.loop_nest with annotation for tiling
3879+ genOMP (converter, symTable, semaCtx, eval, ompNestedLoopCons->value ());
3880+ break ;
3881+ default : {
3882+ unsigned version = semaCtx.langOptions ().OpenMPVersion ;
3883+ TODO (currentLocation,
3884+ " Applying a loop-associated on the loop generated by the " +
3885+ llvm::omp::getOpenMPDirectiveName (nestedDirective, version) +
3886+ " construct" );
3887+ }
3888+ }
36943889 }
36953890 }
36963891
3697- llvm::omp::Directive directive =
3698- std::get<parser::OmpLoopDirective>(beginLoopDirective.t ).v ;
3892+ llvm::omp::Directive directive = getOpenMPDirectiveEnum (beginLoopDirective);
36993893 const parser::CharBlock &source =
37003894 std::get<parser::OmpLoopDirective>(beginLoopDirective.t ).source ;
37013895 ConstructQueue queue{
0 commit comments