Skip to content

Commit c4844a4

Browse files
committed
Move check to CUDAChecker
1 parent f6ae559 commit c4844a4

File tree

5 files changed

+75
-90
lines changed

5 files changed

+75
-90
lines changed

flang/lib/Semantics/assignment.cpp

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ class AssignmentContext {
4242
void Analyze(const parser::AssignmentStmt &);
4343
void Analyze(const parser::PointerAssignmentStmt &);
4444
void Analyze(const parser::ConcurrentControl &);
45-
int deviceConstructDepth_{0};
4645
SemanticsContext &context() { return context_; }
4746

4847
private:
@@ -97,21 +96,6 @@ void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
9796
if (whereDepth_ > 0) {
9897
CheckShape(lhsLoc, &lhs);
9998
}
100-
if (context_.foldingContext().languageFeatures().IsEnabled(
101-
common::LanguageFeature::CUDA)) {
102-
const auto &scope{context_.FindScope(lhsLoc)};
103-
const Scope &progUnit{GetProgramUnitContaining(scope)};
104-
if (!IsCUDADeviceContext(&progUnit) && deviceConstructDepth_ == 0) {
105-
if (Fortran::evaluate::HasCUDADeviceAttrs(lhs) &&
106-
Fortran::evaluate::HasCUDAImplicitTransfer(rhs)) {
107-
if (GetNbOfCUDAManagedOrUnifiedSymbols(lhs) == 1 &&
108-
GetNbOfCUDAManagedOrUnifiedSymbols(rhs) == 1 &&
109-
GetNbOfCUDADeviceSymbols(rhs) == 1)
110-
return; // This is a special case handled on the host.
111-
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
112-
}
113-
}
114-
}
11599
}
116100
}
117101

@@ -254,60 +238,6 @@ void AssignmentChecker::Enter(const parser::MaskedElsewhereStmt &x) {
254238
void AssignmentChecker::Leave(const parser::MaskedElsewhereStmt &) {
255239
context_.value().PopWhereContext();
256240
}
257-
void AssignmentChecker::Enter(const parser::CUFKernelDoConstruct &x) {
258-
++context_.value().deviceConstructDepth_;
259-
}
260-
void AssignmentChecker::Leave(const parser::CUFKernelDoConstruct &) {
261-
--context_.value().deviceConstructDepth_;
262-
}
263-
void AssignmentChecker::Enter(const parser::DoConstruct &x) {
264-
if (x.IsDoConcurrent() &&
265-
context().foldingContext().languageFeatures().IsEnabled(
266-
common::LanguageFeature::StdPar)) {
267-
++context_.value().deviceConstructDepth_;
268-
}
269-
}
270-
void AssignmentChecker::Leave(const parser::DoConstruct &x) {
271-
if (x.IsDoConcurrent() &&
272-
context().foldingContext().languageFeatures().IsEnabled(
273-
common::LanguageFeature::StdPar)) {
274-
--context_.value().deviceConstructDepth_;
275-
}
276-
}
277-
static bool IsOpenACCComputeConstruct(const parser::OpenACCBlockConstruct &x) {
278-
const auto &beginBlockDirective =
279-
std::get<Fortran::parser::AccBeginBlockDirective>(x.t);
280-
const auto &blockDirective =
281-
std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t);
282-
if (blockDirective.v == llvm::acc::ACCD_parallel ||
283-
blockDirective.v == llvm::acc::ACCD_serial ||
284-
blockDirective.v == llvm::acc::ACCD_kernels) {
285-
return true;
286-
}
287-
return false;
288-
}
289-
void AssignmentChecker::Enter(const parser::OpenACCBlockConstruct &x) {
290-
if (IsOpenACCComputeConstruct(x)) {
291-
++context_.value().deviceConstructDepth_;
292-
}
293-
}
294-
void AssignmentChecker::Leave(const parser::OpenACCBlockConstruct &x) {
295-
if (IsOpenACCComputeConstruct(x)) {
296-
--context_.value().deviceConstructDepth_;
297-
}
298-
}
299-
void AssignmentChecker::Enter(const parser::OpenACCCombinedConstruct &) {
300-
++context_.value().deviceConstructDepth_;
301-
}
302-
void AssignmentChecker::Leave(const parser::OpenACCCombinedConstruct &) {
303-
--context_.value().deviceConstructDepth_;
304-
}
305-
void AssignmentChecker::Enter(const parser::OpenACCLoopConstruct &) {
306-
++context_.value().deviceConstructDepth_;
307-
}
308-
void AssignmentChecker::Leave(const parser::OpenACCLoopConstruct &) {
309-
--context_.value().deviceConstructDepth_;
310-
}
311241

312242
} // namespace Fortran::semantics
313243
template class Fortran::common::Indirection<

flang/lib/Semantics/assignment.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,6 @@ class AssignmentChecker : public virtual BaseChecker {
4747
void Leave(const parser::EndWhereStmt &);
4848
void Enter(const parser::MaskedElsewhereStmt &);
4949
void Leave(const parser::MaskedElsewhereStmt &);
50-
void Enter(const parser::CUFKernelDoConstruct &);
51-
void Leave(const parser::CUFKernelDoConstruct &);
52-
void Enter(const parser::OpenACCBlockConstruct &);
53-
void Leave(const parser::OpenACCBlockConstruct &);
54-
void Enter(const parser::OpenACCCombinedConstruct &);
55-
void Leave(const parser::OpenACCCombinedConstruct &);
56-
void Enter(const parser::OpenACCLoopConstruct &);
57-
void Leave(const parser::OpenACCLoopConstruct &);
58-
void Enter(const parser::DoConstruct &);
59-
void Leave(const parser::DoConstruct &);
6050

6151
SemanticsContext &context();
6252

flang/lib/Semantics/check-cuda.cpp

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -685,18 +685,67 @@ void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {
685685
std::get<std::list<parser::CUFReduction>>(directive.t)) {
686686
CheckReduce(context_, reduce);
687687
}
688-
inCUFKernelDoConstruct_ = true;
688+
++deviceConstructDepth_;
689+
}
690+
691+
static bool IsOpenACCComputeConstruct(const parser::OpenACCBlockConstruct &x) {
692+
const auto &beginBlockDirective =
693+
std::get<Fortran::parser::AccBeginBlockDirective>(x.t);
694+
const auto &blockDirective =
695+
std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t);
696+
if (blockDirective.v == llvm::acc::ACCD_parallel ||
697+
blockDirective.v == llvm::acc::ACCD_serial ||
698+
blockDirective.v == llvm::acc::ACCD_kernels) {
699+
return true;
700+
}
701+
return false;
689702
}
690703

691704
void CUDAChecker::Leave(const parser::CUFKernelDoConstruct &) {
692-
inCUFKernelDoConstruct_ = false;
705+
--deviceConstructDepth_;
706+
}
707+
void CUDAChecker::Enter(const parser::OpenACCBlockConstruct &x) {
708+
if (IsOpenACCComputeConstruct(x)) {
709+
++deviceConstructDepth_;
710+
}
711+
}
712+
void CUDAChecker::Leave(const parser::OpenACCBlockConstruct &x) {
713+
if (IsOpenACCComputeConstruct(x)) {
714+
--deviceConstructDepth_;
715+
}
716+
}
717+
void CUDAChecker::Enter(const parser::OpenACCCombinedConstruct &) {
718+
++deviceConstructDepth_;
719+
}
720+
void CUDAChecker::Leave(const parser::OpenACCCombinedConstruct &) {
721+
--deviceConstructDepth_;
722+
}
723+
void CUDAChecker::Enter(const parser::OpenACCLoopConstruct &) {
724+
++deviceConstructDepth_;
725+
}
726+
void CUDAChecker::Leave(const parser::OpenACCLoopConstruct &) {
727+
--deviceConstructDepth_;
728+
}
729+
void CUDAChecker::Enter(const parser::DoConstruct &x) {
730+
if (x.IsDoConcurrent() &&
731+
context_.foldingContext().languageFeatures().IsEnabled(
732+
common::LanguageFeature::StdPar)) {
733+
++deviceConstructDepth_;
734+
}
735+
}
736+
void CUDAChecker::Leave(const parser::DoConstruct &x) {
737+
if (x.IsDoConcurrent() &&
738+
context_.foldingContext().languageFeatures().IsEnabled(
739+
common::LanguageFeature::StdPar)) {
740+
--deviceConstructDepth_;
741+
}
693742
}
694743

695744
void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
696745
auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()};
697746
const auto &scope{context_.FindScope(lhsLoc)};
698747
const Scope &progUnit{GetProgramUnitContaining(scope)};
699-
if (IsCUDADeviceContext(&progUnit) || inCUFKernelDoConstruct_) {
748+
if (IsCUDADeviceContext(&progUnit) || deviceConstructDepth_ > 0) {
700749
return; // Data transfer with assignment is only perform on host.
701750
}
702751

@@ -714,6 +763,15 @@ void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
714763
context_.Say(lhsLoc,
715764
"More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
716765
}
766+
767+
if (Fortran::evaluate::HasCUDADeviceAttrs(assign->lhs) &&
768+
Fortran::evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
769+
if (GetNbOfCUDAManagedOrUnifiedSymbols(assign->lhs) == 1 &&
770+
GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 &&
771+
GetNbOfCUDADeviceSymbols(assign->rhs) == 1)
772+
return; // This is a special case handled on the host.
773+
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
774+
}
717775
}
718776

719777
} // namespace Fortran::semantics

flang/lib/Semantics/check-cuda.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,18 @@ class CUDAChecker : public virtual BaseChecker {
4141
void Enter(const parser::CUFKernelDoConstruct &);
4242
void Leave(const parser::CUFKernelDoConstruct &);
4343
void Enter(const parser::AssignmentStmt &);
44+
void Enter(const parser::OpenACCBlockConstruct &);
45+
void Leave(const parser::OpenACCBlockConstruct &);
46+
void Enter(const parser::OpenACCCombinedConstruct &);
47+
void Leave(const parser::OpenACCCombinedConstruct &);
48+
void Enter(const parser::OpenACCLoopConstruct &);
49+
void Leave(const parser::OpenACCLoopConstruct &);
50+
void Enter(const parser::DoConstruct &);
51+
void Leave(const parser::DoConstruct &);
4452

4553
private:
4654
SemanticsContext &context_;
47-
bool inCUFKernelDoConstruct_ = false;
55+
int deviceConstructDepth_{0};
4856
};
4957

5058
bool CanonicalizeCUDA(parser::Program &);

flang/lib/Semantics/semantics.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ static void WarnUndefinedFunctionResult(
197197

198198
using StatementSemanticsPass1 = ExprChecker;
199199
using StatementSemanticsPass2 = SemanticsVisitor<AllocateChecker,
200-
ArithmeticIfStmtChecker, CaseChecker, CoarrayChecker, DataChecker,
201-
DeallocateChecker, DoForallChecker, IfStmtChecker, IoChecker, MiscChecker,
202-
NamelistChecker, NullifyChecker, PurityChecker, ReturnStmtChecker,
203-
SelectRankConstructChecker, SelectTypeChecker, StopChecker>;
204-
using StatementSemanticsPass3 = SemanticsVisitor<AssignmentChecker>;
200+
ArithmeticIfStmtChecker, AssignmentChecker, CaseChecker, CoarrayChecker,
201+
DataChecker, DeallocateChecker, DoForallChecker, IfStmtChecker, IoChecker,
202+
MiscChecker, NamelistChecker, NullifyChecker, PurityChecker,
203+
ReturnStmtChecker, SelectRankConstructChecker, SelectTypeChecker,
204+
StopChecker>;
205205

206206
static bool PerformStatementSemantics(
207207
SemanticsContext &context, parser::Program &program) {
@@ -212,7 +212,6 @@ static bool PerformStatementSemantics(
212212
StatementSemanticsPass1{context}.Walk(program);
213213
StatementSemanticsPass2 pass2{context};
214214
pass2.Walk(program);
215-
StatementSemanticsPass3{context}.Walk(program);
216215
if (context.languageFeatures().IsEnabled(common::LanguageFeature::OpenACC)) {
217216
SemanticsVisitor<AccStructureChecker>{context}.Walk(program);
218217
}

0 commit comments

Comments
 (0)