@@ -63,6 +63,28 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
63
63
lower::pft::Evaluation &eval,
64
64
mlir::Location loc);
65
65
66
+ static llvm::omp::Directive
67
+ getOpenMPDirectiveEnum (const parser::OmpLoopDirective &beginStatment) {
68
+ return beginStatment.v ;
69
+ }
70
+
71
+ static llvm::omp::Directive getOpenMPDirectiveEnum (
72
+ const parser::OmpBeginLoopDirective &beginLoopDirective) {
73
+ return getOpenMPDirectiveEnum (
74
+ std::get<parser::OmpLoopDirective>(beginLoopDirective.t ));
75
+ }
76
+
77
+ static llvm::omp::Directive
78
+ getOpenMPDirectiveEnum (const parser::OpenMPLoopConstruct &ompLoopConstruct) {
79
+ return getOpenMPDirectiveEnum (
80
+ std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t ));
81
+ }
82
+
83
+ static llvm::omp::Directive getOpenMPDirectiveEnum (
84
+ const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) {
85
+ return getOpenMPDirectiveEnum (ompLoopConstruct.value ());
86
+ }
87
+
66
88
namespace {
67
89
// / Structure holding information that is needed to pass host-evaluated
68
90
// / information to later lowering stages.
@@ -2069,6 +2091,163 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2069
2091
return loopOp;
2070
2092
}
2071
2093
2094
+ static mlir::omp::CanonicalLoopOp
2095
+ genCanonicalLoopOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2096
+ semantics::SemanticsContext &semaCtx,
2097
+ lower::pft::Evaluation &eval, mlir::Location loc,
2098
+ const ConstructQueue &queue,
2099
+ ConstructQueue::const_iterator item,
2100
+ llvm::ArrayRef<const semantics::Symbol *> ivs,
2101
+ llvm::omp::Directive directive, DataSharingProcessor &dsp) {
2102
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2103
+
2104
+ assert (ivs.size () == 1 && " Nested loops not yet implemented" );
2105
+ const semantics::Symbol *iv = ivs[0 ];
2106
+
2107
+ auto &nestedEval = eval.getFirstNestedEvaluation ();
2108
+ if (nestedEval.getIf <parser::DoConstruct>()->IsDoConcurrent ()) {
2109
+ // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will
2110
+ // need to add special cases for this combination.
2111
+ TODO (loc, " DO CONCURRENT as canonical loop not supported" );
2112
+ }
2113
+
2114
+ // Get the loop bounds (and increment)
2115
+ auto &doLoopEval = nestedEval.getFirstNestedEvaluation ();
2116
+ auto *doStmt = doLoopEval.getIf <parser::NonLabelDoStmt>();
2117
+ assert (doStmt && " Expected do loop to be in the nested evaluation" );
2118
+ auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t );
2119
+ assert (loopControl.has_value ());
2120
+ auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u );
2121
+ assert (bounds && " Expected bounds for canonical loop" );
2122
+ lower::StatementContext stmtCtx;
2123
+ mlir::Value loopLBVar = fir::getBase (
2124
+ converter.genExprValue (*semantics::GetExpr (bounds->lower ), stmtCtx));
2125
+ mlir::Value loopUBVar = fir::getBase (
2126
+ converter.genExprValue (*semantics::GetExpr (bounds->upper ), stmtCtx));
2127
+ mlir::Value loopStepVar = [&]() {
2128
+ if (bounds->step ) {
2129
+ return fir::getBase (
2130
+ converter.genExprValue (*semantics::GetExpr (bounds->step ), stmtCtx));
2131
+ }
2132
+
2133
+ // If `step` is not present, assume it is `1`.
2134
+ return firOpBuilder.createIntegerConstant (loc, firOpBuilder.getI32Type (),
2135
+ 1 );
2136
+ }();
2137
+
2138
+ // Get the integer kind for the loop variable and cast the loop bounds
2139
+ size_t loopVarTypeSize = bounds->name .thing .symbol ->GetUltimate ().size ();
2140
+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
2141
+ loopLBVar = firOpBuilder.createConvert (loc, loopVarType, loopLBVar);
2142
+ loopUBVar = firOpBuilder.createConvert (loc, loopVarType, loopUBVar);
2143
+ loopStepVar = firOpBuilder.createConvert (loc, loopVarType, loopStepVar);
2144
+
2145
+ // Start lowering
2146
+ mlir::Value zero = firOpBuilder.createIntegerConstant (loc, loopVarType, 0 );
2147
+ mlir::Value one = firOpBuilder.createIntegerConstant (loc, loopVarType, 1 );
2148
+ mlir::Value isDownwards = firOpBuilder.create <mlir::arith::CmpIOp>(
2149
+ loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
2150
+
2151
+ // Ensure we are counting upwards. If not, negate step and swap lb and ub.
2152
+ mlir::Value negStep =
2153
+ firOpBuilder.create <mlir::arith::SubIOp>(loc, zero, loopStepVar);
2154
+ mlir::Value incr = firOpBuilder.create <mlir::arith::SelectOp>(
2155
+ loc, isDownwards, negStep, loopStepVar);
2156
+ mlir::Value lb = firOpBuilder.create <mlir::arith::SelectOp>(
2157
+ loc, isDownwards, loopUBVar, loopLBVar);
2158
+ mlir::Value ub = firOpBuilder.create <mlir::arith::SelectOp>(
2159
+ loc, isDownwards, loopLBVar, loopUBVar);
2160
+
2161
+ // Compute the trip count assuming lb <= ub. This guarantees that the result
2162
+ // is non-negative and we can use unsigned arithmetic.
2163
+ mlir::Value span = firOpBuilder.create <mlir::arith::SubIOp>(
2164
+ loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
2165
+ mlir::Value tcMinusOne =
2166
+ firOpBuilder.create <mlir::arith::DivUIOp>(loc, span, incr);
2167
+ mlir::Value tcIfLooping = firOpBuilder.create <mlir::arith::AddIOp>(
2168
+ loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
2169
+
2170
+ // Fall back to 0 if lb > ub
2171
+ mlir::Value isZeroTC = firOpBuilder.create <mlir::arith::CmpIOp>(
2172
+ loc, mlir::arith::CmpIPredicate::slt, ub, lb);
2173
+ mlir::Value tripcount = firOpBuilder.create <mlir::arith::SelectOp>(
2174
+ loc, isZeroTC, zero, tcIfLooping);
2175
+
2176
+ // Create the CLI handle.
2177
+ auto newcli = firOpBuilder.create <mlir::omp::NewCliOp>(loc);
2178
+ mlir::Value cli = newcli.getResult ();
2179
+
2180
+ auto ivCallback = [&](mlir::Operation *op)
2181
+ -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
2182
+ mlir::Region ®ion = op->getRegion (0 );
2183
+
2184
+ // Create the op's region skeleton (BB taking the iv as argument)
2185
+ firOpBuilder.createBlock (®ion, {}, {loopVarType}, {loc});
2186
+
2187
+ // Compute the value of the loop variable from the logical iteration number.
2188
+ mlir::Value natIterNum = fir::getBase (region.front ().getArgument (0 ));
2189
+ mlir::Value scaled =
2190
+ firOpBuilder.create <mlir::arith::MulIOp>(loc, natIterNum, loopStepVar);
2191
+ mlir::Value userVal =
2192
+ firOpBuilder.create <mlir::arith::AddIOp>(loc, loopLBVar, scaled);
2193
+
2194
+ // The argument is not currently in memory, so make a temporary for the
2195
+ // argument, and store it there, then bind that location to the argument.
2196
+ mlir::Operation *storeOp =
2197
+ createAndSetPrivatizedLoopVar (converter, loc, userVal, iv);
2198
+
2199
+ firOpBuilder.setInsertionPointAfter (storeOp);
2200
+ return {iv};
2201
+ };
2202
+
2203
+ // Create the omp.canonical_loop operation
2204
+ auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
2205
+ OpWithBodyGenInfo (converter, symTable, semaCtx, loc, nestedEval,
2206
+ directive)
2207
+ .setClauses (&item->clauses )
2208
+ .setDataSharingProcessor (&dsp)
2209
+ .setGenRegionEntryCb (ivCallback),
2210
+ queue, item, tripcount, cli);
2211
+
2212
+ firOpBuilder.setInsertionPointAfter (canonLoop);
2213
+ return canonLoop;
2214
+ }
2215
+
2216
+ static void genUnrollOp (Fortran::lower::AbstractConverter &converter,
2217
+ Fortran::lower::SymMap &symTable,
2218
+ lower::StatementContext &stmtCtx,
2219
+ Fortran::semantics::SemanticsContext &semaCtx,
2220
+ Fortran::lower::pft::Evaluation &eval,
2221
+ mlir::Location loc, const ConstructQueue &queue,
2222
+ ConstructQueue::const_iterator item) {
2223
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2224
+
2225
+ mlir::omp::LoopRelatedClauseOps loopInfo;
2226
+ llvm::SmallVector<const semantics::Symbol *> iv;
2227
+ collectLoopRelatedInfo (converter, loc, eval, item->clauses , loopInfo, iv);
2228
+
2229
+ // Clauses for unrolling not yet implemnted
2230
+ ClauseProcessor cp (converter, semaCtx, item->clauses );
2231
+ cp.processTODO <clause::Partial, clause::Full>(
2232
+ loc, llvm::omp::Directive::OMPD_unroll);
2233
+
2234
+ // Even though unroll does not support data-sharing clauses, but this is
2235
+ // required to fill the symbol table.
2236
+ DataSharingProcessor dsp (converter, semaCtx, item->clauses , eval,
2237
+ /* shouldCollectPreDeterminedSymbols=*/ true ,
2238
+ /* useDelayedPrivatization=*/ false , symTable);
2239
+ dsp.processStep1 ();
2240
+
2241
+ // Emit the associated loop
2242
+ auto canonLoop =
2243
+ genCanonicalLoopOp (converter, symTable, semaCtx, eval, loc, queue, item,
2244
+ iv, llvm::omp::Directive::OMPD_unroll, dsp);
2245
+
2246
+ // Apply unrolling to it
2247
+ auto cli = canonLoop.getCli ();
2248
+ firOpBuilder.create <mlir::omp::UnrollHeuristicOp>(loc, cli);
2249
+ }
2250
+
2072
2251
static mlir::omp::MaskedOp
2073
2252
genMaskedOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2074
2253
lower::StatementContext &stmtCtx,
@@ -3249,12 +3428,14 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
3249
3428
newOp = genTeamsOp (converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
3250
3429
item);
3251
3430
break ;
3252
- case llvm::omp::Directive::OMPD_tile:
3253
- case llvm::omp::Directive::OMPD_unroll: {
3431
+ case llvm::omp::Directive::OMPD_tile: {
3254
3432
unsigned version = semaCtx.langOptions ().OpenMPVersion ;
3255
3433
TODO (loc, " Unhandled loop directive (" +
3256
3434
llvm::omp::getOpenMPDirectiveName (dir, version) + " )" );
3257
3435
}
3436
+ case llvm::omp::Directive::OMPD_unroll:
3437
+ genUnrollOp (converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
3438
+ break ;
3258
3439
// case llvm::omp::Directive::OMPD_workdistribute:
3259
3440
case llvm::omp::Directive::OMPD_workshare:
3260
3441
newOp = genWorkshareOp (converter, symTable, stmtCtx, semaCtx, eval, loc,
@@ -3690,12 +3871,25 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
3690
3871
if (auto *ompNestedLoopCons{
3691
3872
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
3692
3873
&*optLoopCons)}) {
3693
- genOMP (converter, symTable, semaCtx, eval, ompNestedLoopCons->value ());
3874
+ llvm::omp::Directive nestedDirective =
3875
+ getOpenMPDirectiveEnum (*ompNestedLoopCons);
3876
+ switch (nestedDirective) {
3877
+ case llvm::omp::Directive::OMPD_tile:
3878
+ // Emit the omp.loop_nest with annotation for tiling
3879
+ genOMP (converter, symTable, semaCtx, eval, ompNestedLoopCons->value ());
3880
+ break ;
3881
+ default : {
3882
+ unsigned version = semaCtx.langOptions ().OpenMPVersion ;
3883
+ TODO (currentLocation,
3884
+ " Applying a loop-associated on the loop generated by the " +
3885
+ llvm::omp::getOpenMPDirectiveName (nestedDirective, version) +
3886
+ " construct" );
3887
+ }
3888
+ }
3694
3889
}
3695
3890
}
3696
3891
3697
- llvm::omp::Directive directive =
3698
- std::get<parser::OmpLoopDirective>(beginLoopDirective.t ).v ;
3892
+ llvm::omp::Directive directive = getOpenMPDirectiveEnum (beginLoopDirective);
3699
3893
const parser::CharBlock &source =
3700
3894
std::get<parser::OmpLoopDirective>(beginLoopDirective.t ).source ;
3701
3895
ConstructQueue queue{
0 commit comments