Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5139,6 +5139,12 @@ static constexpr StringRef getOpenMPVariantManglingSeparatorStr() {
bool IsArmStreamingFunction(const FunctionDecl *FD,
bool IncludeLocallyStreaming);

/// Returns whether the given FunctionDecl has Arm ZA state.
///
/// This is true when the function's prototype carries AArch64 SME attributes
/// whose ZA state is anything other than FunctionType::ARM_None, or when the
/// declaration has an ArmNewAttr for which isNewZA() holds.
bool hasArmZAState(const FunctionDecl *FD);

/// Returns whether the given FunctionDecl has Arm ZT0 state.
///
/// This is true when the function's prototype carries AArch64 SME attributes
/// whose ZT0 state is anything other than FunctionType::ARM_None, or when the
/// declaration has an ArmNewAttr for which isNewZT0() holds.
bool hasArmZT0State(const FunctionDecl *FD);

} // namespace clang

#endif // LLVM_CLANG_AST_DECL_H
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -3864,6 +3864,9 @@ def err_sme_definition_using_za_in_non_sme_target : Error<
"function using ZA state requires 'sme'">;
def err_sme_definition_using_zt0_in_non_sme2_target : Error<
"function using ZT0 state requires 'sme2'">;
// Reported when an OpenMP captured region is started inside a function that
// uses Arm SME state. The %0 argument selects the wording:
//   0 - streaming functions, 1 - functions with ZA state,
//   2 - functions with ZT0 state.
def err_sme_openmp_captured_region : Error<
"OpenMP captured regions are not yet supported in "
"%select{streaming functions|functions with ZA state|functions with ZT0 state}0">;
def warn_sme_streaming_pass_return_vl_to_non_streaming : Warning<
"%select{returning|passing}0 a VL-dependent argument %select{from|to}0 a function with a different"
" streaming-mode is undefined behaviour when the streaming and non-streaming vector lengths are different at runtime">,
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5844,3 +5844,17 @@ bool clang::IsArmStreamingFunction(const FunctionDecl *FD,

return false;
}

bool clang::hasArmZAState(const FunctionDecl *FD) {
const auto *T = FD->getType()->getAs<FunctionProtoType>();
return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
FunctionType::ARM_None) ||
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
}

bool clang::hasArmZT0State(const FunctionDecl *FD) {
const auto *T = FD->getType()->getAs<FunctionProtoType>();
return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
FunctionType::ARM_None) ||
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
}
14 changes: 0 additions & 14 deletions clang/lib/Sema/SemaARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,20 +624,6 @@ static bool checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
return true;
}

static bool hasArmZAState(const FunctionDecl *FD) {
const auto *T = FD->getType()->getAs<FunctionProtoType>();
return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
FunctionType::ARM_None) ||
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
}

static bool hasArmZT0State(const FunctionDecl *FD) {
const auto *T = FD->getType()->getAs<FunctionProtoType>();
return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
FunctionType::ARM_None) ||
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
}

static ArmSMEState getSMEState(unsigned BuiltinID) {
switch (BuiltinID) {
default:
Expand Down
22 changes: 22 additions & 0 deletions clang/lib/Sema/SemaStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4568,9 +4568,28 @@ buildCapturedStmtCaptureList(Sema &S, CapturedRegionScopeInfo *RSI,
return false;
}

/// Classifies an OpenMP captured region that is being opened inside a function
/// using Arm SME state.
///
/// \param S    The Sema instance; supplies the language options and the
///             current function declaration (lambdas included).
/// \param Kind The kind of captured region being started.
/// \returns The %select index for diag::err_sme_openmp_captured_region
///          (0 = streaming function, 1 = function with ZA state,
///          2 = function with ZT0 state), or std::nullopt when OpenMP is
///          disabled, the region is not an OpenMP region, there is no current
///          function, or the function uses none of that state.
static std::optional<int>
isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) {
  // Indices into the %select of err_sme_openmp_captured_region.
  constexpr int InStreamingFunction = 0;
  constexpr int InFunctionWithZAState = 1;
  constexpr int InFunctionWithZT0State = 2;

  const bool IsOpenMPRegion = S.getLangOpts().OpenMP && Kind == CR_OpenMP;
  if (!IsOpenMPRegion)
    return std::nullopt;

  const FunctionDecl *EnclosingFn =
      S.getCurFunctionDecl(/*AllowLambda=*/true);
  if (!EnclosingFn)
    return std::nullopt;

  // The checks are ordered: streaming mode takes precedence over ZA state,
  // which in turn takes precedence over ZT0 state.
  if (IsArmStreamingFunction(EnclosingFn, /*IncludeLocallyStreaming=*/true))
    return InStreamingFunction;
  if (hasArmZAState(EnclosingFn))
    return InFunctionWithZAState;
  if (hasArmZT0State(EnclosingFn))
    return InFunctionWithZT0State;

  return std::nullopt;
}

void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
CapturedRegionKind Kind,
unsigned NumParams) {
if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;

CapturedDecl *CD = nullptr;
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, NumParams);

Expand Down Expand Up @@ -4602,6 +4621,9 @@ void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
CapturedRegionKind Kind,
ArrayRef<CapturedParamNameType> Params,
unsigned OpenMPCaptureLevel) {
if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;

CapturedDecl *CD = nullptr;
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, Params.size());

Expand Down
68 changes: 68 additions & 0 deletions clang/test/Sema/aarch64-sme-attrs-openmp-captured-region.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify %s
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify=expected-cpp -x c++ %s

int compute(int);

// A function declared __arm_streaming: the `omp parallel for` below must be
// rejected with the "streaming functions" wording in both the C and C++ runs.
void streaming_openmp_captured_region(int* out) __arm_streaming
{
// expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}}
#pragma omp parallel for num_threads(32)
for(int ci =0;ci< 8;ci++)
{
out[ci] =compute(ci);
}
}

// A function marked __arm_locally_streaming: the `omp parallel for` below must
// be rejected with the "streaming functions" wording in both the C and C++ runs.
__arm_locally_streaming void locally_streaming_openmp_captured_region(int* out)
{
// expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}}
#pragma omp parallel for num_threads(32)
for(int ci =0;ci< 8;ci++)
{
out[ci] = compute(ci);
}
}

// A non-streaming function declared __arm_inout("za"): the `omp parallel for`
// below must be rejected with the "functions with ZA state" wording in both
// the C and C++ runs.
void za_state_captured_region(int* out) __arm_inout("za")
{
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
#pragma omp parallel for num_threads(32)
for(int ci =0;ci< 8;ci++)
{
out[ci] =compute(ci);
}
}

// A non-streaming function declared __arm_inout("zt0"): the `omp parallel for`
// below must be rejected with the "functions with ZT0 state" wording in both
// the C and C++ runs.
void zt0_state_openmp_captured_region(int* out) __arm_inout("zt0")
{
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
#pragma omp parallel for num_threads(32)
for(int ci =0;ci< 8;ci++)
{
out[ci] = compute(ci);
}
}

/// OpenMP directives that don't create a captured region are okay:

// A streaming function with inout ZA and ZT0 state: `omp unroll full` is used
// here as a directive that must be accepted without any diagnostic.
void streaming_function_openmp(int* out) __arm_streaming __arm_inout("za", "zt0")
{
#pragma omp unroll full
for(int ci =0;ci< 8;ci++)
{
out[ci] =compute(ci);
}
}

// A locally-streaming function with inout ZA and ZT0 state: `omp unroll full`
// is used here as a directive that must be accepted without any diagnostic.
__arm_locally_streaming void locally_streaming_openmp(int* out) __arm_inout("za", "zt0")
{
#pragma omp unroll full
for(int ci =0;ci< 8;ci++)
{
out[ci] = compute(ci);
}
}
Loading