@@ -3114,50 +3114,127 @@ class FirConverter : public Fortran::lower::AbstractConverter {
3114
3114
llvm::SmallVector<mlir::Value> ivValues;
3115
3115
Fortran::lower::pft::Evaluation *loopEval =
3116
3116
&getEval ().getFirstNestedEvaluation ();
3117
- for (unsigned i = 0 ; i < nestedLoops; ++i) {
3118
- const Fortran::parser::LoopControl *loopControl;
3119
- mlir::Location crtLoc = loc;
3120
- if (i == 0 ) {
3121
- loopControl = &*outerDoConstruct->GetLoopControl ();
3122
- crtLoc =
3123
- genLocation (Fortran::parser::FindSourceLocation (outerDoConstruct));
3124
- } else {
3125
- auto *doCons = loopEval->getIf <Fortran::parser::DoConstruct>();
3126
- assert (doCons && " expect do construct" );
3127
- loopControl = &*doCons->GetLoopControl ();
3128
- crtLoc = genLocation (Fortran::parser::FindSourceLocation (*doCons));
3117
+ if (outerDoConstruct->IsDoConcurrent ()) {
3118
+ // Handle DO CONCURRENT
3119
+ locs.push_back (
3120
+ genLocation (Fortran::parser::FindSourceLocation (outerDoConstruct)));
3121
+ const Fortran::parser::LoopControl *loopControl =
3122
+ &*outerDoConstruct->GetLoopControl ();
3123
+ const auto &concurrent =
3124
+ std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u );
3125
+
3126
+ if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t )
3127
+ .empty ())
3128
+ TODO (loc, " DO CONCURRENT with locality spec" );
3129
+
3130
+ const auto &concurrentHeader =
3131
+ std::get<Fortran::parser::ConcurrentHeader>(concurrent.t );
3132
+ const auto &controls =
3133
+ std::get<std::list<Fortran::parser::ConcurrentControl>>(
3134
+ concurrentHeader.t );
3135
+
3136
+ for (const auto &control : controls) {
3137
+ mlir::Value lb = fir::getBase (genExprValue (
3138
+ *Fortran::semantics::GetExpr (std::get<1 >(control.t )), stmtCtx));
3139
+ mlir::Value ub = fir::getBase (genExprValue (
3140
+ *Fortran::semantics::GetExpr (std::get<2 >(control.t )), stmtCtx));
3141
+ mlir::Value step;
3142
+
3143
+ if (const auto &expr =
3144
+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
3145
+ control.t ))
3146
+ step = fir::getBase (
3147
+ genExprValue (*Fortran::semantics::GetExpr (*expr), stmtCtx));
3148
+ else
3149
+ step = builder->create <mlir::arith::ConstantIndexOp>(
3150
+ loc, 1 ); // Use index type directly
3151
+
3152
+ // Ensure lb, ub, and step are of index type using fir.convert
3153
+ mlir::Type indexType = builder->getIndexType ();
3154
+ lb = builder->create <fir::ConvertOp>(loc, indexType, lb);
3155
+ ub = builder->create <fir::ConvertOp>(loc, indexType, ub);
3156
+ step = builder->create <fir::ConvertOp>(loc, indexType, step);
3157
+
3158
+ lbs.push_back (lb);
3159
+ ubs.push_back (ub);
3160
+ steps.push_back (step);
3161
+
3162
+ const auto &name = std::get<Fortran::parser::Name>(control.t );
3163
+
3164
+ // Handle induction variable
3165
+ mlir::Value ivValue = getSymbolAddress (*name.symbol );
3166
+ std::size_t ivTypeSize = name.symbol ->size ();
3167
+ if (ivTypeSize == 0 )
3168
+ llvm::report_fatal_error (" unexpected induction variable size" );
3169
+ mlir::Type ivTy = builder->getIntegerType (ivTypeSize * 8 );
3170
+
3171
+ if (!ivValue) {
3172
+ // DO CONCURRENT induction variables are not mapped yet since they are
3173
+ // local to the DO CONCURRENT scope.
3174
+ mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint ();
3175
+ builder->setInsertionPointToStart (builder->getAllocaBlock ());
3176
+ ivValue = builder->createTemporaryAlloc (
3177
+ loc, ivTy, toStringRef (name.symbol ->name ()));
3178
+ builder->restoreInsertionPoint (insPt);
3179
+ }
3180
+
3181
+ // Create the hlfir.declare operation using the symbol's name
3182
+ auto declareOp = builder->create <hlfir::DeclareOp>(
3183
+ loc, ivValue, toStringRef (name.symbol ->name ()));
3184
+ ivValue = declareOp.getResult (0 );
3185
+
3186
+ // Bind the symbol to the declared variable
3187
+ bindSymbol (*name.symbol , ivValue);
3188
+ ivValues.push_back (ivValue);
3189
+ ivTypes.push_back (ivTy);
3190
+ ivLocs.push_back (loc);
3129
3191
}
3192
+ } else {
3193
+ for (unsigned i = 0 ; i < nestedLoops; ++i) {
3194
+ const Fortran::parser::LoopControl *loopControl;
3195
+ mlir::Location crtLoc = loc;
3196
+ if (i == 0 ) {
3197
+ loopControl = &*outerDoConstruct->GetLoopControl ();
3198
+ crtLoc = genLocation (
3199
+ Fortran::parser::FindSourceLocation (outerDoConstruct));
3200
+ } else {
3201
+ auto *doCons = loopEval->getIf <Fortran::parser::DoConstruct>();
3202
+ assert (doCons && " expect do construct" );
3203
+ loopControl = &*doCons->GetLoopControl ();
3204
+ crtLoc = genLocation (Fortran::parser::FindSourceLocation (*doCons));
3205
+ }
3130
3206
3131
- locs.push_back (crtLoc);
3132
-
3133
- const Fortran::parser::LoopControl::Bounds *bounds =
3134
- std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
3135
- assert (bounds && " Expected bounds on the loop construct" );
3136
-
3137
- Fortran::semantics::Symbol &ivSym =
3138
- bounds->name .thing .symbol ->GetUltimate ();
3139
- ivValues.push_back (getSymbolAddress (ivSym));
3140
-
3141
- lbs.push_back (builder->createConvert (
3142
- crtLoc, idxTy,
3143
- fir::getBase (genExprValue (*Fortran::semantics::GetExpr (bounds->lower ),
3144
- stmtCtx))));
3145
- ubs.push_back (builder->createConvert (
3146
- crtLoc, idxTy,
3147
- fir::getBase (genExprValue (*Fortran::semantics::GetExpr (bounds->upper ),
3148
- stmtCtx))));
3149
- if (bounds->step )
3150
- steps.push_back (builder->createConvert (
3207
+ locs.push_back (crtLoc);
3208
+
3209
+ const Fortran::parser::LoopControl::Bounds *bounds =
3210
+ std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
3211
+ assert (bounds && " Expected bounds on the loop construct" );
3212
+
3213
+ Fortran::semantics::Symbol &ivSym =
3214
+ bounds->name .thing .symbol ->GetUltimate ();
3215
+ ivValues.push_back (getSymbolAddress (ivSym));
3216
+
3217
+ lbs.push_back (builder->createConvert (
3151
3218
crtLoc, idxTy,
3152
3219
fir::getBase (genExprValue (
3153
- *Fortran::semantics::GetExpr (bounds->step ), stmtCtx))));
3154
- else // If `step` is not present, assume it is `1`.
3155
- steps.push_back (builder->createIntegerConstant (loc, idxTy, 1 ));
3156
-
3157
- ivTypes.push_back (idxTy);
3158
- ivLocs.push_back (crtLoc);
3159
- if (i < nestedLoops - 1 )
3160
- loopEval = &*std::next (loopEval->getNestedEvaluations ().begin ());
3220
+ *Fortran::semantics::GetExpr (bounds->lower ), stmtCtx))));
3221
+ ubs.push_back (builder->createConvert (
3222
+ crtLoc, idxTy,
3223
+ fir::getBase (genExprValue (
3224
+ *Fortran::semantics::GetExpr (bounds->upper ), stmtCtx))));
3225
+ if (bounds->step )
3226
+ steps.push_back (builder->createConvert (
3227
+ crtLoc, idxTy,
3228
+ fir::getBase (genExprValue (
3229
+ *Fortran::semantics::GetExpr (bounds->step ), stmtCtx))));
3230
+ else // If `step` is not present, assume it is `1`.
3231
+ steps.push_back (builder->createIntegerConstant (loc, idxTy, 1 ));
3232
+
3233
+ ivTypes.push_back (idxTy);
3234
+ ivLocs.push_back (crtLoc);
3235
+ if (i < nestedLoops - 1 )
3236
+ loopEval = &*std::next (loopEval->getNestedEvaluations ().begin ());
3237
+ }
3161
3238
}
3162
3239
3163
3240
auto op = builder->create <cuf::KernelOp>(
0 commit comments