Skip to content

Commit a67566b

Browse files
wangzpgiclementval
andauthored
Allow do concurrent inside cuf kernel directive (#127693)
Allow do concurrent inside cuf kernel directive to avoid the following Lowering error: ``` void {anonymous}::FirConverter::genFIR(const Fortran::parser::CUFKernelDoConstruct&): Assertion `bounds && "Expected bounds on the loop construct"' failed. ``` --------- Co-authored-by: Valentin Clement (バレンタイン クレメン) <[email protected]>
1 parent 5981335 commit a67566b

File tree

4 files changed

+183
-43
lines changed

4 files changed

+183
-43
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 117 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3114,50 +3114,127 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31143114
llvm::SmallVector<mlir::Value> ivValues;
31153115
Fortran::lower::pft::Evaluation *loopEval =
31163116
&getEval().getFirstNestedEvaluation();
3117-
for (unsigned i = 0; i < nestedLoops; ++i) {
3118-
const Fortran::parser::LoopControl *loopControl;
3119-
mlir::Location crtLoc = loc;
3120-
if (i == 0) {
3121-
loopControl = &*outerDoConstruct->GetLoopControl();
3122-
crtLoc =
3123-
genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
3124-
} else {
3125-
auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
3126-
assert(doCons && "expect do construct");
3127-
loopControl = &*doCons->GetLoopControl();
3128-
crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
3117+
if (outerDoConstruct->IsDoConcurrent()) {
3118+
// Handle DO CONCURRENT
3119+
locs.push_back(
3120+
genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct)));
3121+
const Fortran::parser::LoopControl *loopControl =
3122+
&*outerDoConstruct->GetLoopControl();
3123+
const auto &concurrent =
3124+
std::get<Fortran::parser::LoopControl::Concurrent>(loopControl->u);
3125+
3126+
if (!std::get<std::list<Fortran::parser::LocalitySpec>>(concurrent.t)
3127+
.empty())
3128+
TODO(loc, "DO CONCURRENT with locality spec");
3129+
3130+
const auto &concurrentHeader =
3131+
std::get<Fortran::parser::ConcurrentHeader>(concurrent.t);
3132+
const auto &controls =
3133+
std::get<std::list<Fortran::parser::ConcurrentControl>>(
3134+
concurrentHeader.t);
3135+
3136+
for (const auto &control : controls) {
3137+
mlir::Value lb = fir::getBase(genExprValue(
3138+
*Fortran::semantics::GetExpr(std::get<1>(control.t)), stmtCtx));
3139+
mlir::Value ub = fir::getBase(genExprValue(
3140+
*Fortran::semantics::GetExpr(std::get<2>(control.t)), stmtCtx));
3141+
mlir::Value step;
3142+
3143+
if (const auto &expr =
3144+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
3145+
control.t))
3146+
step = fir::getBase(
3147+
genExprValue(*Fortran::semantics::GetExpr(*expr), stmtCtx));
3148+
else
3149+
step = builder->create<mlir::arith::ConstantIndexOp>(
3150+
loc, 1); // Use index type directly
3151+
3152+
// Ensure lb, ub, and step are of index type using fir.convert
3153+
mlir::Type indexType = builder->getIndexType();
3154+
lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
3155+
ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
3156+
step = builder->create<fir::ConvertOp>(loc, indexType, step);
3157+
3158+
lbs.push_back(lb);
3159+
ubs.push_back(ub);
3160+
steps.push_back(step);
3161+
3162+
const auto &name = std::get<Fortran::parser::Name>(control.t);
3163+
3164+
// Handle induction variable
3165+
mlir::Value ivValue = getSymbolAddress(*name.symbol);
3166+
std::size_t ivTypeSize = name.symbol->size();
3167+
if (ivTypeSize == 0)
3168+
llvm::report_fatal_error("unexpected induction variable size");
3169+
mlir::Type ivTy = builder->getIntegerType(ivTypeSize * 8);
3170+
3171+
if (!ivValue) {
3172+
// DO CONCURRENT induction variables are not mapped yet since they are
3173+
// local to the DO CONCURRENT scope.
3174+
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
3175+
builder->setInsertionPointToStart(builder->getAllocaBlock());
3176+
ivValue = builder->createTemporaryAlloc(
3177+
loc, ivTy, toStringRef(name.symbol->name()));
3178+
builder->restoreInsertionPoint(insPt);
3179+
}
3180+
3181+
// Create the hlfir.declare operation using the symbol's name
3182+
auto declareOp = builder->create<hlfir::DeclareOp>(
3183+
loc, ivValue, toStringRef(name.symbol->name()));
3184+
ivValue = declareOp.getResult(0);
3185+
3186+
// Bind the symbol to the declared variable
3187+
bindSymbol(*name.symbol, ivValue);
3188+
ivValues.push_back(ivValue);
3189+
ivTypes.push_back(ivTy);
3190+
ivLocs.push_back(loc);
31293191
}
3192+
} else {
3193+
for (unsigned i = 0; i < nestedLoops; ++i) {
3194+
const Fortran::parser::LoopControl *loopControl;
3195+
mlir::Location crtLoc = loc;
3196+
if (i == 0) {
3197+
loopControl = &*outerDoConstruct->GetLoopControl();
3198+
crtLoc = genLocation(
3199+
Fortran::parser::FindSourceLocation(outerDoConstruct));
3200+
} else {
3201+
auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
3202+
assert(doCons && "expect do construct");
3203+
loopControl = &*doCons->GetLoopControl();
3204+
crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
3205+
}
31303206

3131-
locs.push_back(crtLoc);
3132-
3133-
const Fortran::parser::LoopControl::Bounds *bounds =
3134-
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
3135-
assert(bounds && "Expected bounds on the loop construct");
3136-
3137-
Fortran::semantics::Symbol &ivSym =
3138-
bounds->name.thing.symbol->GetUltimate();
3139-
ivValues.push_back(getSymbolAddress(ivSym));
3140-
3141-
lbs.push_back(builder->createConvert(
3142-
crtLoc, idxTy,
3143-
fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
3144-
stmtCtx))));
3145-
ubs.push_back(builder->createConvert(
3146-
crtLoc, idxTy,
3147-
fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
3148-
stmtCtx))));
3149-
if (bounds->step)
3150-
steps.push_back(builder->createConvert(
3207+
locs.push_back(crtLoc);
3208+
3209+
const Fortran::parser::LoopControl::Bounds *bounds =
3210+
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
3211+
assert(bounds && "Expected bounds on the loop construct");
3212+
3213+
Fortran::semantics::Symbol &ivSym =
3214+
bounds->name.thing.symbol->GetUltimate();
3215+
ivValues.push_back(getSymbolAddress(ivSym));
3216+
3217+
lbs.push_back(builder->createConvert(
31513218
crtLoc, idxTy,
31523219
fir::getBase(genExprValue(
3153-
*Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
3154-
else // If `step` is not present, assume it is `1`.
3155-
steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
3156-
3157-
ivTypes.push_back(idxTy);
3158-
ivLocs.push_back(crtLoc);
3159-
if (i < nestedLoops - 1)
3160-
loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
3220+
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
3221+
ubs.push_back(builder->createConvert(
3222+
crtLoc, idxTy,
3223+
fir::getBase(genExprValue(
3224+
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
3225+
if (bounds->step)
3226+
steps.push_back(builder->createConvert(
3227+
crtLoc, idxTy,
3228+
fir::getBase(genExprValue(
3229+
*Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
3230+
else // If `step` is not present, assume it is `1`.
3231+
steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
3232+
3233+
ivTypes.push_back(idxTy);
3234+
ivLocs.push_back(crtLoc);
3235+
if (i < nestedLoops - 1)
3236+
loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
3237+
}
31613238
}
31623239

31633240
auto op = builder->create<cuf::KernelOp>(

flang/lib/Semantics/check-cuda.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,21 @@ static int DoConstructTightNesting(
525525
return 0;
526526
}
527527
innerBlock = &std::get<parser::Block>(doConstruct->t);
528+
if (doConstruct->IsDoConcurrent()) {
529+
const auto &loopControl = doConstruct->GetLoopControl();
530+
if (loopControl) {
531+
if (const auto *concurrentControl{
532+
std::get_if<parser::LoopControl::Concurrent>(&loopControl->u)}) {
533+
const auto &concurrentHeader =
534+
std::get<Fortran::parser::ConcurrentHeader>(concurrentControl->t);
535+
const auto &controls =
536+
std::get<std::list<Fortran::parser::ConcurrentControl>>(
537+
concurrentHeader.t);
538+
return controls.size();
539+
}
540+
}
541+
return 0;
542+
}
528543
if (innerBlock->size() == 1) {
529544
if (const auto *execConstruct{
530545
std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) {
@@ -598,9 +613,14 @@ void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {
598613
std::get<std::optional<parser::DoConstruct>>(x.t))};
599614
const parser::Block *innerBlock{nullptr};
600615
if (DoConstructTightNesting(doConstruct, innerBlock) < depth) {
601-
context_.Say(source,
602-
"!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,
603-
std::intmax_t{depth});
616+
if (doConstruct && doConstruct->IsDoConcurrent())
617+
context_.Say(source,
618+
"!$CUF KERNEL DO (%jd) must be followed by a DO CONCURRENT construct with at least %jd indices"_err_en_US,
619+
std::intmax_t{depth}, std::intmax_t{depth});
620+
else
621+
context_.Say(source,
622+
"!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,
623+
std::intmax_t{depth});
604624
}
605625
if (innerBlock) {
606626
DeviceContextChecker<true>{context_}.Check(*innerBlock);

flang/test/Lower/CUDA/cuda-doconc.cuf

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
! Check if do concurrent works inside cuf kernel directive
4+
5+
subroutine doconc1
6+
integer :: i, n
7+
integer, managed :: a(3)
8+
a(:) = -1
9+
n = 3
10+
n = n - 1
11+
!$cuf kernel do
12+
do concurrent(i=1:n)
13+
a(i) = 1
14+
end do
15+
end
16+
17+
! CHECK: func.func @_QPdoconc1() {
18+
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
19+
! CHECK: cuf.kernel<<<*, *>>>
20+
! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<i32>
21+
22+
subroutine doconc2
23+
integer :: i, j, m, n
24+
integer, managed :: a(2, 4)
25+
m = 2
26+
n = 4
27+
a(:,:) = -1
28+
!$cuf kernel do
29+
do concurrent(i=1:m,j=1:n)
30+
a(i,j) = i+j
31+
end do
32+
end
33+
34+
! CHECK: func.func @_QPdoconc2() {
35+
! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
36+
! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
37+
! CHECK: cuf.kernel<<<*, *>>> (%arg0 : i32, %arg1 : i32) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) {
38+
! CHECK: %{{.*}} = fir.load %[[DECLI]]#0 : !fir.ref<i32>
39+
! CHECK: %{{.*}} = fir.load %[[DECLJ]]#0 : !fir.ref<i32>

flang/test/Semantics/cuf09.cuf

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ program main
133133
!$cuf kernel do <<< 1, 2 >>>
134134
do concurrent (j=1:10)
135135
end do
136+
!ERROR: !$CUF KERNEL DO (2) must be followed by a DO CONCURRENT construct with at least 2 indices
137+
!$cuf kernel do(2) <<< 1, 2 >>>
138+
do concurrent (j=1:10)
139+
end do
136140
!$cuf kernel do <<< 1, 2 >>>
137141
do 1 j=1,10
138142
1 continue ! ok

0 commit comments

Comments
 (0)