Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5139,6 +5139,12 @@ static constexpr StringRef getOpenMPVariantManglingSeparatorStr() {
bool IsArmStreamingFunction(const FunctionDecl *FD,
bool IncludeLocallyStreaming);

/// Returns whether the given FunctionDecl has Arm ZA state.
bool hasArmZAState(const FunctionDecl *FD);

/// Returns whether the given FunctionDecl has Arm ZT0 state.
bool hasArmZT0State(const FunctionDecl *FD);

} // namespace clang

#endif // LLVM_CLANG_AST_DECL_H
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -3864,6 +3864,9 @@ def err_sme_definition_using_za_in_non_sme_target : Error<
"function using ZA state requires 'sme'">;
def err_sme_definition_using_zt0_in_non_sme2_target : Error<
"function using ZT0 state requires 'sme2'">;
// Reported when an OpenMP captured region is started inside an Arm SME
// function. %0 selects the kind of enclosing function:
//   0 = streaming function, 1 = function with ZA state,
//   2 = function with ZT0 state.
def err_sme_openmp_captured_region : Error<
"OpenMP captured regions are not yet supported in "
"%select{streaming functions|functions with ZA state|functions with ZT0 state}0">;
def warn_sme_streaming_pass_return_vl_to_non_streaming : Warning<
"%select{returning|passing}0 a VL-dependent argument %select{from|to}0 a function with a different"
" streaming-mode is undefined behaviour when the streaming and non-streaming vector lengths are different at runtime">,
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5844,3 +5844,17 @@ bool clang::IsArmStreamingFunction(const FunctionDecl *FD,

return false;
}

bool clang::hasArmZAState(const FunctionDecl *FD) {
const auto *T = FD->getType()->getAs<FunctionProtoType>();
return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
FunctionType::ARM_None) ||
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
}

bool clang::hasArmZT0State(const FunctionDecl *FD) {
const auto *T = FD->getType()->getAs<FunctionProtoType>();
return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
FunctionType::ARM_None) ||
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
}
14 changes: 0 additions & 14 deletions clang/lib/Sema/SemaARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,20 +624,6 @@ static bool checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
return true;
}

static bool hasArmZAState(const FunctionDecl *FD) {
const auto *T = FD->getType()->getAs<FunctionProtoType>();
return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
FunctionType::ARM_None) ||
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
}

static bool hasArmZT0State(const FunctionDecl *FD) {
const auto *T = FD->getType()->getAs<FunctionProtoType>();
return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
FunctionType::ARM_None) ||
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
}

static ArmSMEState getSMEState(unsigned BuiltinID) {
switch (BuiltinID) {
default:
Expand Down
21 changes: 21 additions & 0 deletions clang/lib/Sema/SemaStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4568,9 +4568,27 @@ buildCapturedStmtCaptureList(Sema &S, CapturedRegionScopeInfo *RSI,
return false;
}

/// Determines whether a captured region of kind \p Kind is an OpenMP captured
/// region located inside an Arm SME function.
///
/// \returns std::nullopt when OpenMP is not enabled, when \p Kind is not
/// CR_OpenMP, when there is no current function declaration (lambdas are
/// allowed), or when the current function is neither streaming nor has
/// ZA/ZT0 state. Otherwise returns an index identifying the first matching
/// category, checked in this order:
///   0 - streaming function (locally streaming included),
///   1 - function with ZA state,
///   2 - function with ZT0 state.
static std::optional<int>
isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) {
  const bool IsOpenMPRegion = S.getLangOpts().OpenMP && Kind == CR_OpenMP;
  if (!IsOpenMPRegion)
    return std::nullopt;

  const FunctionDecl *CurFD = S.getCurFunctionDecl(/*AllowLambda=*/true);
  if (!CurFD)
    return std::nullopt;

  if (IsArmStreamingFunction(CurFD, /*IncludeLocallyStreaming=*/true))
    return 0; // in streaming functions
  if (hasArmZAState(CurFD))
    return 1; // in functions with ZA state
  if (hasArmZT0State(CurFD))
    return 2; // in functions with ZT0 state

  return std::nullopt;
}

void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
CapturedRegionKind Kind,
unsigned NumParams) {
if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;

CapturedDecl *CD = nullptr;
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, NumParams);

Expand Down Expand Up @@ -4602,6 +4620,9 @@ void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
CapturedRegionKind Kind,
ArrayRef<CapturedParamNameType> Params,
unsigned OpenMPCaptureLevel) {
if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;

CapturedDecl *CD = nullptr;
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, Params.size());

Expand Down
81 changes: 81 additions & 0 deletions clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify %s
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify=expected-cpp -x c++ %s

int compute(int);

// An `omp parallel for` inside an `__arm_streaming` function: both the C and
// the C++ run expect the "streaming functions" error on the pragma line.
void streaming_openmp_captured_region(int * out) __arm_streaming {
// expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}}
#pragma omp parallel for num_threads(32)
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}

// An `omp parallel for` inside an `__arm_locally_streaming` function: both
// runs expect the same "streaming functions" error on the pragma line.
__arm_locally_streaming void locally_streaming_openmp_captured_region(int * out) {
// expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}}
#pragma omp parallel for num_threads(32)
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}

// An `omp parallel for` inside a non-streaming function declared
// `__arm_inout("za")`: both runs expect the "functions with ZA state" error.
void za_state_captured_region(int * out) __arm_inout("za") {
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
#pragma omp parallel for num_threads(32)
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}

// An `omp parallel for` inside a function marked `__arm_new("za")`: both runs
// expect the "functions with ZA state" error.
__arm_new("za") void new_za_state_captured_region(int * out) {
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
#pragma omp parallel for num_threads(32)
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}

// An `omp parallel for` inside a function declared `__arm_inout("zt0")`: both
// runs expect the "functions with ZT0 state" error.
void zt0_state_openmp_captured_region(int * out) __arm_inout("zt0") {
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
#pragma omp parallel for num_threads(32)
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}

// An `omp parallel for` inside a function marked `__arm_new("zt0")`: both runs
// expect the "functions with ZT0 state" error.
__arm_new("zt0") void new_zt0_state_openmp_captured_region(int * out) {
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
#pragma omp parallel for num_threads(32)
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}

/// OpenMP directives that don't create a captured region are okay:

// No diagnostic expected: `omp unroll full` is used inside an
// `__arm_streaming` function that also has inout ZA and ZT0 state.
void streaming_function_openmp(int * out) __arm_streaming __arm_inout("za", "zt0") {
#pragma omp unroll full
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}

// No diagnostic expected: `omp unroll full` is used inside an
// `__arm_locally_streaming` function that also has inout ZA and ZT0 state.
__arm_locally_streaming void locally_streaming_openmp(int * out) __arm_inout("za", "zt0") {
#pragma omp unroll full
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}

// No diagnostic expected: `omp unroll full` is used inside a function marked
// `__arm_new("za", "zt0")`.
__arm_new("za", "zt0") void arm_new_openmp(int * out) {
#pragma omp unroll full
for (int ci = 0; ci < 8; ci++) {
out[ci] = compute(ci);
}
}
Loading