Skip to content

Commit f6ae559

Browse files
committed
[flang][cuda] Do not produce data transfer in offloaded do concurrent
1 parent deba201 commit f6ae559

File tree

9 files changed

+55
-9
lines changed

9 files changed

+55
-9
lines changed

flang/include/flang/Optimizer/Builder/CUFCommon.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
2727
mlir::SymbolTable &symTab);
2828

2929
bool isCUDADeviceContext(mlir::Operation *op);
30-
bool isCUDADeviceContext(mlir::Region &);
30+
bool isCUDADeviceContext(mlir::Region &, bool isStdParEnabled = false);
3131
bool isRegisteredDeviceGlobal(fir::GlobalOp op);
3232
bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
3333

flang/include/flang/Support/Fortran-features.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
5555
SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
5656
IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
5757
ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
58-
InaccessibleDeferredOverride, CudaWarpMatchFunction)
58+
InaccessibleDeferredOverride, CudaWarpMatchFunction, StdPar)
5959

6060
// Portability and suspicious usage warnings
6161
ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,

flang/lib/Lower/Bridge.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4886,7 +4886,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
48864886
mlir::Location loc = getCurrentLocation();
48874887
fir::FirOpBuilder &builder = getFirOpBuilder();
48884888

4889-
bool isInDeviceContext = cuf::isCUDADeviceContext(builder.getRegion());
4889+
bool isInDeviceContext = cuf::isCUDADeviceContext(
4890+
builder.getRegion(), getFoldingContext().languageFeatures().IsEnabled(
4891+
Fortran::common::LanguageFeature::StdPar));
48904892

48914893
bool isCUDATransfer =
48924894
IsCUDADataTransfer(assign.lhs, assign.rhs) && !isInDeviceContext;

flang/lib/Optimizer/Builder/CUFCommon.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ bool cuf::isCUDADeviceContext(mlir::Operation *op) {
4343
// for it.
4444
// If the insertion point is inside an OpenACC region op, it is considered
4545
// device context.
46-
bool cuf::isCUDADeviceContext(mlir::Region &region) {
46+
bool cuf::isCUDADeviceContext(mlir::Region &region, bool isStdParEnabled) {
4747
if (region.getParentOfType<cuf::KernelOp>())
4848
return true;
4949
if (region.getParentOfType<mlir::acc::ComputeRegionOpInterface>())
@@ -56,6 +56,8 @@ bool cuf::isCUDADeviceContext(mlir::Region &region) {
5656
cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
5757
}
5858
}
59+
if (isStdParEnabled && region.getParentOfType<fir::DoConcurrentLoopOp>())
60+
return true;
5961
return false;
6062
}
6163

flang/lib/Semantics/assignment.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,20 @@ void AssignmentChecker::Enter(const parser::CUFKernelDoConstruct &x) {
260260
void AssignmentChecker::Leave(const parser::CUFKernelDoConstruct &) {
261261
--context_.value().deviceConstructDepth_;
262262
}
263+
void AssignmentChecker::Enter(const parser::DoConstruct &x) {
264+
if (x.IsDoConcurrent() &&
265+
context().foldingContext().languageFeatures().IsEnabled(
266+
common::LanguageFeature::StdPar)) {
267+
++context_.value().deviceConstructDepth_;
268+
}
269+
}
270+
void AssignmentChecker::Leave(const parser::DoConstruct &x) {
271+
if (x.IsDoConcurrent() &&
272+
context().foldingContext().languageFeatures().IsEnabled(
273+
common::LanguageFeature::StdPar)) {
274+
--context_.value().deviceConstructDepth_;
275+
}
276+
}
263277
static bool IsOpenACCComputeConstruct(const parser::OpenACCBlockConstruct &x) {
264278
const auto &beginBlockDirective =
265279
std::get<Fortran::parser::AccBeginBlockDirective>(x.t);

flang/lib/Semantics/assignment.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct MaskedElsewhereStmt;
2121
struct PointerAssignmentStmt;
2222
struct WhereConstructStmt;
2323
struct WhereStmt;
24+
struct DoConstruct;
2425
} // namespace Fortran::parser
2526

2627
namespace Fortran::semantics {
@@ -54,6 +55,8 @@ class AssignmentChecker : public virtual BaseChecker {
5455
void Leave(const parser::OpenACCCombinedConstruct &);
5556
void Enter(const parser::OpenACCLoopConstruct &);
5657
void Leave(const parser::OpenACCLoopConstruct &);
58+
void Enter(const parser::DoConstruct &);
59+
void Leave(const parser::DoConstruct &);
5760

5861
SemanticsContext &context();
5962

flang/lib/Semantics/semantics.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ static void WarnUndefinedFunctionResult(
197197

198198
using StatementSemanticsPass1 = ExprChecker;
199199
using StatementSemanticsPass2 = SemanticsVisitor<AllocateChecker,
200-
ArithmeticIfStmtChecker, AssignmentChecker, CaseChecker, CoarrayChecker,
201-
DataChecker, DeallocateChecker, DoForallChecker, IfStmtChecker, IoChecker,
202-
MiscChecker, NamelistChecker, NullifyChecker, PurityChecker,
203-
ReturnStmtChecker, SelectRankConstructChecker, SelectTypeChecker,
204-
StopChecker>;
200+
ArithmeticIfStmtChecker, CaseChecker, CoarrayChecker, DataChecker,
201+
DeallocateChecker, DoForallChecker, IfStmtChecker, IoChecker, MiscChecker,
202+
NamelistChecker, NullifyChecker, PurityChecker, ReturnStmtChecker,
203+
SelectRankConstructChecker, SelectTypeChecker, StopChecker>;
204+
using StatementSemanticsPass3 = SemanticsVisitor<AssignmentChecker>;
205205

206206
static bool PerformStatementSemantics(
207207
SemanticsContext &context, parser::Program &program) {
@@ -212,6 +212,7 @@ static bool PerformStatementSemantics(
212212
StatementSemanticsPass1{context}.Walk(program);
213213
StatementSemanticsPass2 pass2{context};
214214
pass2.Walk(program);
215+
StatementSemanticsPass3{context}.Walk(program);
215216
if (context.languageFeatures().IsEnabled(common::LanguageFeature::OpenACC)) {
216217
SemanticsVisitor<AccStructureChecker>{context}.Walk(program);
217218
}

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,19 @@ end subroutine
403403
! CHECK-LABEL: func.func @_QPsub20()
404404
! CHECK-NOT: cuf.data_transfer
405405
! CHECK: hlfir.assign
406+
407+
subroutine sub21()
408+
real, allocatable,device:: a(:,:), b(:,:)
409+
real:: s
410+
integer:: i,j,N=16
411+
allocate(a(N,N),b(N,N))
412+
do concurrent(i=1:N, j=1:N) reduce(+:s)
413+
b(i,j)=a(i,j)**2
414+
s=s+b(i,j)
415+
end do
416+
end subroutine
417+
418+
! CHECK-LABEL: func.func @_QPsub21()
419+
! CHECK: fir.do_concurrent.loop
420+
! CHECK-NOT: cuf.data_transfer
421+
! CHECK: hlfir.assign

flang/tools/bbc/bbc.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
223223
llvm::cl::desc("enable CUDA Fortran"),
224224
llvm::cl::init(false));
225225

226+
static llvm::cl::opt<bool> enableStdPar("stdpar",
227+
llvm::cl::desc("enable stdpar"),
228+
llvm::cl::init(false));
229+
226230
static llvm::cl::opt<bool>
227231
disableCUDAWarpFunction("fcuda-disable-warp-function",
228232
llvm::cl::desc("Disable CUDA Warp Function"),
@@ -608,6 +612,10 @@ int main(int argc, char **argv) {
608612
options.features.Enable(Fortran::common::LanguageFeature::CUDA);
609613
}
610614

615+
if (enableStdPar) {
616+
options.features.Enable(Fortran::common::LanguageFeature::StdPar);
617+
}
618+
611619
if (disableCUDAWarpFunction) {
612620
options.features.Enable(
613621
Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);

0 commit comments

Comments
 (0)