diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 36e58e456dea3..74e8115c9f9f0 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3074,50 +3074,127 @@ class FirConverter : public Fortran::lower::AbstractConverter { llvm::SmallVector ivValues; Fortran::lower::pft::Evaluation *loopEval = &getEval().getFirstNestedEvaluation(); - for (unsigned i = 0; i < nestedLoops; ++i) { - const Fortran::parser::LoopControl *loopControl; - mlir::Location crtLoc = loc; - if (i == 0) { - loopControl = &*outerDoConstruct->GetLoopControl(); - crtLoc = - genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct)); - } else { - auto *doCons = loopEval->getIf(); - assert(doCons && "expect do construct"); - loopControl = &*doCons->GetLoopControl(); - crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons)); + if (outerDoConstruct->IsDoConcurrent()) { + // Handle DO CONCURRENT + locs.push_back( + genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct))); + const Fortran::parser::LoopControl *loopControl = + &*outerDoConstruct->GetLoopControl(); + const auto &concurrent = + std::get(loopControl->u); + + if (!std::get>(concurrent.t) + .empty()) + TODO(loc, "DO CONCURRENT with locality spec"); + + const auto &concurrentHeader = + std::get(concurrent.t); + const auto &controls = + std::get>( + concurrentHeader.t); + + for (const auto &control : controls) { + mlir::Value lb = fir::getBase(genExprValue( + *Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx)); + mlir::Value ub = fir::getBase(genExprValue( + *Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx)); + mlir::Value step; + + if (const auto &expr = + std::get>( + control.t)) + step = fir::getBase( + genExprValue(*Fortran::semantics::GetExpr(*expr), stmtCtx)); + else + step = builder->create( + loc, 1); // Use index type directly + + // Ensure lb, ub, and step are of index type using fir.convert + mlir::Type indexType = builder->getIndexType(); + lb = builder->create(loc, indexType, lb); + ub = builder->create(loc, indexType, ub); + step = builder->create(loc, indexType, step); + + lbs.push_back(lb); + ubs.push_back(ub); + steps.push_back(step); + + const auto &name = std::get(control.t); + + // Handle induction variable + mlir::Value ivValue = getSymbolAddress(*name.symbol); + std::size_t ivTypeSize = name.symbol->size(); + if (ivTypeSize == 0) + llvm::report_fatal_error("unexpected induction variable size"); + mlir::Type ivTy = builder->getIntegerType(ivTypeSize * 8); + + if (!ivValue) { + // DO CONCURRENT induction variables are not mapped yet since they are + // local to the DO CONCURRENT scope. + mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint(); + builder->setInsertionPointToStart(builder->getAllocaBlock()); + ivValue = builder->createTemporaryAlloc( + loc, ivTy, toStringRef(name.symbol->name())); + builder->restoreInsertionPoint(insPt); + } + + // Create the hlfir.declare operation using the symbol's name + auto declareOp = builder->create( + loc, ivValue, toStringRef(name.symbol->name())); + ivValue = declareOp.getResult(0); + + // Bind the symbol to the declared variable + bindSymbol(*name.symbol, ivValue); + ivValues.push_back(ivValue); + ivTypes.push_back(ivTy); + ivLocs.push_back(loc); } + } else { + for (unsigned i = 0; i < nestedLoops; ++i) { + const Fortran::parser::LoopControl *loopControl; + mlir::Location crtLoc = loc; + if (i == 0) { + loopControl = &*outerDoConstruct->GetLoopControl(); + crtLoc = genLocation( + Fortran::parser::FindSourceLocation(outerDoConstruct)); + } else { + auto *doCons = loopEval->getIf(); + assert(doCons && "expect do construct"); + loopControl = &*doCons->GetLoopControl(); + crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons)); + } - locs.push_back(crtLoc); - - const Fortran::parser::LoopControl::Bounds *bounds = - std::get_if(&loopControl->u); - assert(bounds && "Expected bounds on the loop construct"); - - Fortran::semantics::Symbol &ivSym = - bounds->name.thing.symbol->GetUltimate(); - ivValues.push_back(getSymbolAddress(ivSym)); - - lbs.push_back(builder->createConvert( - crtLoc, idxTy, - fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower), - stmtCtx)))); - ubs.push_back(builder->createConvert( - crtLoc, idxTy, - fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper), - stmtCtx)))); - if (bounds->step) - steps.push_back(builder->createConvert( + locs.push_back(crtLoc); + + const Fortran::parser::LoopControl::Bounds *bounds = + std::get_if(&loopControl->u); + assert(bounds && "Expected bounds on the loop construct"); + + Fortran::semantics::Symbol &ivSym = + bounds->name.thing.symbol->GetUltimate(); + ivValues.push_back(getSymbolAddress(ivSym)); + + lbs.push_back(builder->createConvert( crtLoc, idxTy, fir::getBase(genExprValue( - *Fortran::semantics::GetExpr(bounds->step), stmtCtx)))); - else // If `step` is not present, assume it is `1`. - steps.push_back(builder->createIntegerConstant(loc, idxTy, 1)); - - ivTypes.push_back(idxTy); - ivLocs.push_back(crtLoc); - if (i < nestedLoops - 1) - loopEval = &*std::next(loopEval->getNestedEvaluations().begin()); + *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)))); + ubs.push_back(builder->createConvert( + crtLoc, idxTy, + fir::getBase(genExprValue( + *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)))); + if (bounds->step) + steps.push_back(builder->createConvert( + crtLoc, idxTy, + fir::getBase(genExprValue( + *Fortran::semantics::GetExpr(bounds->step), stmtCtx)))); + else // If `step` is not present, assume it is `1`. + steps.push_back(builder->createIntegerConstant(loc, idxTy, 1)); + + ivTypes.push_back(idxTy); + ivLocs.push_back(crtLoc); + if (i < nestedLoops - 1) + loopEval = &*std::next(loopEval->getNestedEvaluations().begin()); + } } auto op = builder->create( diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp index c85a84ea5527f..cb7a383284e63 100644 --- a/flang/lib/Semantics/check-cuda.cpp +++ b/flang/lib/Semantics/check-cuda.cpp @@ -525,6 +525,21 @@ static int DoConstructTightNesting( return 0; } innerBlock = &std::get(doConstruct->t); + if (doConstruct->IsDoConcurrent()) { + const auto &loopControl = doConstruct->GetLoopControl(); + if (loopControl) { + if (const auto *concurrentControl{ + std::get_if(&loopControl->u)}) { + const auto &concurrentHeader = + std::get(concurrentControl->t); + const auto &controls = + std::get>( + concurrentHeader.t); + return controls.size(); + } + } + return 0; + } if (innerBlock->size() == 1) { if (const auto *execConstruct{ std::get_if(&innerBlock->front().u)}) { @@ -598,9 +613,14 @@ void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) { std::get>(x.t))}; const parser::Block *innerBlock{nullptr}; if (DoConstructTightNesting(doConstruct, innerBlock) < depth) { - context_.Say(source, - "!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US, - std::intmax_t{depth}); + if (doConstruct && doConstruct->IsDoConcurrent()) + context_.Say(source, + "!$CUF KERNEL DO (%jd) must be followed by a DO CONCURRENT construct with at least %jd indices"_err_en_US, + std::intmax_t{depth}, std::intmax_t{depth}); + else + context_.Say(source, + "!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US, + std::intmax_t{depth}); } if (innerBlock) { DeviceContextChecker{context_}.Check(*innerBlock); diff --git a/flang/test/Lower/CUDA/cuda-doconc.cuf b/flang/test/Lower/CUDA/cuda-doconc.cuf new file mode 100644 index 0000000000000..32cd1676b22f4 --- /dev/null +++ b/flang/test/Lower/CUDA/cuda-doconc.cuf @@ -0,0 +1,39 @@ +! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s + +! Check if do concurrent works inside cuf kernel directive + +subroutine doconc1 + integer :: i, n + integer, managed :: a(3) + a(:) = -1 + n = 3 + n = n - 1 + !$cuf kernel do + do concurrent(i=1:n) + a(i) = 1 + end do +end + +! CHECK: func.func @_QPdoconc1() { +! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc1Ei"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: cuf.kernel<<<*, *>>> +! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref + +subroutine doconc2 + integer :: i, j, m, n + integer, managed :: a(2, 4) + m = 2 + n = 4 + a(:,:) = -1 + !$cuf kernel do + do concurrent(i=1:m,j=1:n) + a(i,j) = i+j + end do +end + +! CHECK: func.func @_QPdoconc2() { +! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: cuf.kernel<<<*, *>>> (%arg0 : i32, %arg1 : i32) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) { +! CHECK: %{{.*}} = fir.load %[[DECLI]]#0 : !fir.ref +! CHECK: %{{.*}} = fir.load %[[DECLJ]]#0 : !fir.ref diff --git a/flang/test/Semantics/cuf09.cuf b/flang/test/Semantics/cuf09.cuf index 7d32e0d70ba36..a8c62db65c6d5 100644 --- a/flang/test/Semantics/cuf09.cuf +++ b/flang/test/Semantics/cuf09.cuf @@ -133,6 +133,10 @@ program main !$cuf kernel do <<< 1, 2 >>> do concurrent (j=1:10) end do + !ERROR: !$CUF KERNEL DO (2) must be followed by a DO CONCURRENT construct with at least 2 indices + !$cuf kernel do(2) <<< 1, 2 >>> + do concurrent (j=1:10) + end do !$cuf kernel do <<< 1, 2 >>> do 1 j=1,10 1 continue ! ok