Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 125 additions & 40 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3074,50 +3074,135 @@ class FirConverter : public Fortran::lower::AbstractConverter {
llvm::SmallVector<mlir::Value> ivValues;
Fortran::lower::pft::Evaluation *loopEval =
&getEval().getFirstNestedEvaluation();
for (unsigned i = 0; i < nestedLoops; ++i) {
const Fortran::parser::LoopControl *loopControl;
mlir::Location crtLoc = loc;
if (i == 0) {
loopControl = &*outerDoConstruct->GetLoopControl();
crtLoc =
genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
} else {
auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
assert(doCons && "expect do construct");
loopControl = &*doCons->GetLoopControl();
crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
bool isDoConcurrent = outerDoConstruct->IsDoConcurrent();
if (isDoConcurrent) {
// Handle DO CONCURRENT
locs.push_back(
genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct)));
const Fortran::parser::LoopControl *loopControl =
&*outerDoConstruct->GetLoopControl();
const auto &concurrent =
std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u);

if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t)
.empty())
TODO(loc, "DO CONCURRENT with locality spec");

const auto &concurrentHeader =
std::get<Fortran::parser::ConcurrentHeader>(concurrent.t);
const auto &controls =
std::get<std::list<Fortran::parser::ConcurrentControl>>(
concurrentHeader.t);

for (const auto &control : controls) {
auto lb = fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx));
auto ub = fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx));
mlir::Value step;

if (const auto &expr =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
control.t)) {
step = fir::getBase(
genExprValue(*Fortran::semantics::GetExpr(*expr), stmtCtx));
} else {
step = builder->create<mlir::arith::ConstantIndexOp>(
loc, 1); // Use index type directly
}

// Ensure lb, ub, and step are of index type using fir.convert
auto indexType = builder->getIndexType();
if (lb.getType() != indexType) {
lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
}
if (ub.getType() != indexType) {
ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
}
if (step.getType() != indexType) {
step = builder->create<fir::ConvertOp>(loc, indexType, step);
}

lbs.push_back(lb);
ubs.push_back(ub);
steps.push_back(step);

const auto &name = std::get<Fortran::parser::Name>(control.t);

// Handle induction variable
mlir::Value ivValue = getSymbolAddress(*name.symbol);
std::size_t ivTypeSize = name.symbol->size();
if (ivTypeSize == 0)
llvm::report_fatal_error("unexpected induction variable size");
mlir::Type ivTy = builder->getIntegerType(ivTypeSize * 8);

if (!ivValue) {
// DO CONCURRENT induction variables are not mapped yet since they are
// local to the DO CONCURRENT scope.
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
builder->setInsertionPointToStart(builder->getAllocaBlock());
ivValue = builder->createTemporaryAlloc(
loc, ivTy, toStringRef(name.symbol->name()));
builder->restoreInsertionPoint(insPt);
}

// Create the hlfir.declare operation using the symbol's name
auto declareOp = builder->create<hlfir::DeclareOp>(
loc, ivValue, toStringRef(name.symbol->name()));
ivValue = declareOp.getResult(0);

// Bind the symbol to the declared variable
bindSymbol(*name.symbol, ivValue);
ivValues.push_back(ivValue);
ivTypes.push_back(ivTy);
ivLocs.push_back(loc);
}
} else {
for (unsigned i = 0; i < nestedLoops; ++i) {
const Fortran::parser::LoopControl *loopControl;
mlir::Location crtLoc = loc;
if (i == 0) {
loopControl = &*outerDoConstruct->GetLoopControl();
crtLoc = genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct));
} else {
auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
assert(doCons && "expect do construct");
loopControl = &*doCons->GetLoopControl();
crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
}

locs.push_back(crtLoc);

locs.push_back(crtLoc);

const Fortran::parser::LoopControl::Bounds *bounds =
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds on the loop construct");

Fortran::semantics::Symbol &ivSym =
bounds->name.thing.symbol->GetUltimate();
ivValues.push_back(getSymbolAddress(ivSym));

lbs.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
stmtCtx))));
ubs.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
stmtCtx))));
if (bounds->step)
steps.push_back(builder->createConvert(
const Fortran::parser::LoopControl::Bounds *bounds =
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds on the loop construct");

Fortran::semantics::Symbol &ivSym =
bounds->name.thing.symbol->GetUltimate();
ivValues.push_back(getSymbolAddress(ivSym));

lbs.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
else // If `step` is not present, assume it is `1`.
steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));

ivTypes.push_back(idxTy);
ivLocs.push_back(crtLoc);
if (i < nestedLoops - 1)
loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
ubs.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
if (bounds->step)
steps.push_back(builder->createConvert(
crtLoc, idxTy,
fir::getBase(genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
else // If `step` is not present, assume it is `1`.
steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));

ivTypes.push_back(idxTy);
ivLocs.push_back(crtLoc);
if (i < nestedLoops - 1)
loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
}
}

auto op = builder->create<cuf::KernelOp>(
Expand Down
20 changes: 20 additions & 0 deletions flang/test/Lower/CUDA/cuda-doconc.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s

! Check if do concurrent works inside cuf kernel directive

program main
integer :: i, n
integer, managed :: a(3)
a(:) = -1
n = 3
n = n - 1
!$cuf kernel do
do concurrent(i=1:n)
a(i) = 1
end do
end

! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "main"} {
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: cuf.kernel<<<*, *>>>
! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<i32>