@@ -3114,50 +3114,127 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31143114 llvm::SmallVector<mlir::Value> ivValues;
31153115 Fortran::lower::pft::Evaluation *loopEval =
31163116 &getEval ().getFirstNestedEvaluation ();
3117- for (unsigned i = 0 ; i < nestedLoops; ++i) {
3118- const Fortran::parser::LoopControl *loopControl;
3119- mlir::Location crtLoc = loc;
3120- if (i == 0 ) {
3121- loopControl = &*outerDoConstruct->GetLoopControl ();
3122- crtLoc =
3123- genLocation (Fortran::parser::FindSourceLocation (outerDoConstruct));
3124- } else {
3125- auto *doCons = loopEval->getIf <Fortran::parser::DoConstruct>();
3126- assert (doCons && " expect do construct" );
3127- loopControl = &*doCons->GetLoopControl ();
3128- crtLoc = genLocation (Fortran::parser::FindSourceLocation (*doCons));
3117+ if (outerDoConstruct->IsDoConcurrent ()) {
3118+ // Handle DO CONCURRENT
3119+ locs.push_back (
3120+ genLocation (Fortran::parser::FindSourceLocation (outerDoConstruct)));
3121+ const Fortran::parser::LoopControl *loopControl =
3122+ &*outerDoConstruct->GetLoopControl ();
3123+ const auto &concurrent =
3124+ std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u );
3125+
3126+ if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t )
3127+ .empty ())
3128+ TODO (loc, " DO CONCURRENT with locality spec" );
3129+
3130+ const auto &concurrentHeader =
3131+ std::get<Fortran::parser::ConcurrentHeader>(concurrent.t );
3132+ const auto &controls =
3133+ std::get<std::list<Fortran::parser::ConcurrentControl>>(
3134+ concurrentHeader.t );
3135+
3136+ for (const auto &control : controls) {
3137+ mlir::Value lb = fir::getBase (genExprValue (
3138+ *Fortran::semantics::GetExpr (std::get<1 >(control.t )), stmtCtx));
3139+ mlir::Value ub = fir::getBase (genExprValue (
3140+ *Fortran::semantics::GetExpr (std::get<2 >(control.t )), stmtCtx));
3141+ mlir::Value step;
3142+
3143+ if (const auto &expr =
3144+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
3145+ control.t ))
3146+ step = fir::getBase (
3147+ genExprValue (*Fortran::semantics::GetExpr (*expr), stmtCtx));
3148+ else
3149+ step = builder->create <mlir::arith::ConstantIndexOp>(
3150+ loc, 1 ); // Use index type directly
3151+
3152+ // Ensure lb, ub, and step are of index type using fir.convert
3153+ mlir::Type indexType = builder->getIndexType ();
3154+ lb = builder->create <fir::ConvertOp>(loc, indexType, lb);
3155+ ub = builder->create <fir::ConvertOp>(loc, indexType, ub);
3156+ step = builder->create <fir::ConvertOp>(loc, indexType, step);
3157+
3158+ lbs.push_back (lb);
3159+ ubs.push_back (ub);
3160+ steps.push_back (step);
3161+
3162+ const auto &name = std::get<Fortran::parser::Name>(control.t );
3163+
3164+ // Handle induction variable
3165+ mlir::Value ivValue = getSymbolAddress (*name.symbol );
3166+ std::size_t ivTypeSize = name.symbol ->size ();
3167+ if (ivTypeSize == 0 )
3168+ llvm::report_fatal_error (" unexpected induction variable size" );
3169+ mlir::Type ivTy = builder->getIntegerType (ivTypeSize * 8 );
3170+
3171+ if (!ivValue) {
3172+ // DO CONCURRENT induction variables are not mapped yet since they are
3173+ // local to the DO CONCURRENT scope.
3174+ mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint ();
3175+ builder->setInsertionPointToStart (builder->getAllocaBlock ());
3176+ ivValue = builder->createTemporaryAlloc (
3177+ loc, ivTy, toStringRef (name.symbol ->name ()));
3178+ builder->restoreInsertionPoint (insPt);
3179+ }
3180+
3181+ // Create the hlfir.declare operation using the symbol's name
3182+ auto declareOp = builder->create <hlfir::DeclareOp>(
3183+ loc, ivValue, toStringRef (name.symbol ->name ()));
3184+ ivValue = declareOp.getResult (0 );
3185+
3186+ // Bind the symbol to the declared variable
3187+ bindSymbol (*name.symbol , ivValue);
3188+ ivValues.push_back (ivValue);
3189+ ivTypes.push_back (ivTy);
3190+ ivLocs.push_back (loc);
31293191 }
3192+ } else {
3193+ for (unsigned i = 0 ; i < nestedLoops; ++i) {
3194+ const Fortran::parser::LoopControl *loopControl;
3195+ mlir::Location crtLoc = loc;
3196+ if (i == 0 ) {
3197+ loopControl = &*outerDoConstruct->GetLoopControl ();
3198+ crtLoc = genLocation (
3199+ Fortran::parser::FindSourceLocation (outerDoConstruct));
3200+ } else {
3201+ auto *doCons = loopEval->getIf <Fortran::parser::DoConstruct>();
3202+ assert (doCons && " expect do construct" );
3203+ loopControl = &*doCons->GetLoopControl ();
3204+ crtLoc = genLocation (Fortran::parser::FindSourceLocation (*doCons));
3205+ }
31303206
3131- locs.push_back (crtLoc);
3132-
3133- const Fortran::parser::LoopControl::Bounds *bounds =
3134- std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
3135- assert (bounds && " Expected bounds on the loop construct" );
3136-
3137- Fortran::semantics::Symbol &ivSym =
3138- bounds->name .thing .symbol ->GetUltimate ();
3139- ivValues.push_back (getSymbolAddress (ivSym));
3140-
3141- lbs.push_back (builder->createConvert (
3142- crtLoc, idxTy,
3143- fir::getBase (genExprValue (*Fortran::semantics::GetExpr (bounds->lower ),
3144- stmtCtx))));
3145- ubs.push_back (builder->createConvert (
3146- crtLoc, idxTy,
3147- fir::getBase (genExprValue (*Fortran::semantics::GetExpr (bounds->upper ),
3148- stmtCtx))));
3149- if (bounds->step )
3150- steps.push_back (builder->createConvert (
3207+ locs.push_back (crtLoc);
3208+
3209+ const Fortran::parser::LoopControl::Bounds *bounds =
3210+ std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
3211+ assert (bounds && " Expected bounds on the loop construct" );
3212+
3213+ Fortran::semantics::Symbol &ivSym =
3214+ bounds->name .thing .symbol ->GetUltimate ();
3215+ ivValues.push_back (getSymbolAddress (ivSym));
3216+
3217+ lbs.push_back (builder->createConvert (
31513218 crtLoc, idxTy,
31523219 fir::getBase (genExprValue (
3153- *Fortran::semantics::GetExpr (bounds->step ), stmtCtx))));
3154- else // If `step` is not present, assume it is `1`.
3155- steps.push_back (builder->createIntegerConstant (loc, idxTy, 1 ));
3156-
3157- ivTypes.push_back (idxTy);
3158- ivLocs.push_back (crtLoc);
3159- if (i < nestedLoops - 1 )
3160- loopEval = &*std::next (loopEval->getNestedEvaluations ().begin ());
3220+ *Fortran::semantics::GetExpr (bounds->lower ), stmtCtx))));
3221+ ubs.push_back (builder->createConvert (
3222+ crtLoc, idxTy,
3223+ fir::getBase (genExprValue (
3224+ *Fortran::semantics::GetExpr (bounds->upper ), stmtCtx))));
3225+ if (bounds->step )
3226+ steps.push_back (builder->createConvert (
3227+ crtLoc, idxTy,
3228+ fir::getBase (genExprValue (
3229+ *Fortran::semantics::GetExpr (bounds->step ), stmtCtx))));
3230+ else // If `step` is not present, assume it is `1`.
3231+ steps.push_back (builder->createIntegerConstant (loc, idxTy, 1 ));
3232+
3233+ ivTypes.push_back (idxTy);
3234+ ivLocs.push_back (crtLoc);
3235+ if (i < nestedLoops - 1 )
3236+ loopEval = &*std::next (loopEval->getNestedEvaluations ().begin ());
3237+ }
31613238 }
31623239
31633240 auto op = builder->create <cuf::KernelOp>(
0 commit comments