Skip to content

Commit 82a4277

Browse files
authored
[OpenACC][NFC] AST changes for Reduction combiner (llvm#162573)
This is the first patch of a handful to get the reduction combiner recipe lowering properly. THIS patch is NFC as it doesn't actually change anything except the structure of the AST. For each 'combiner' recipe we need a 'LHS' 'RHS' and expression to represent the operation. Each var-reference can have 1 or more combiners. IF it is a plain scalar, or a struct with the proper operator, or an array of either of those, there will be 1. HOWEVER, aggregates without the proper operator are supposed to be broken down and done from their elements (which can only be scalars). In this case, we will represent 1 'combiner' recipe per field-decl. This patch only puts the infrastructure in place to do so, future patches wll do the work to fill this in.
1 parent 31103ef commit 82a4277

File tree

5 files changed

+57
-8
lines changed

5 files changed

+57
-8
lines changed

clang/include/clang/AST/OpenACCClause.h

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,13 +1280,31 @@ class OpenACCCreateClause final
12801280
// 'main' declaration used for initializaiton, which is fixed.
12811281
struct OpenACCReductionRecipe {
12821282
VarDecl *AllocaDecl;
1283-
// TODO: OpenACC: this should eventually have the operations here too.
12841283

1285-
OpenACCReductionRecipe(VarDecl *A) : AllocaDecl(A) {}
1284+
// A combiner recipe is represented by an operation expression. However, in
1285+
// order to generate these properly, we have to make up a LHS and a RHS
1286+
// expression for the purposes of generation.
1287+
struct CombinerRecipe {
1288+
VarDecl *LHS;
1289+
VarDecl *RHS;
1290+
Expr *Op;
1291+
};
1292+
1293+
// Contains a collection of the recipe elements we need for the combiner:
1294+
// -For Scalars, there will be 1 element, just the combiner for that scalar.
1295+
// -For a struct with a valid operator, this will be 1 element, just that
1296+
// call.
1297+
// -For a struct without the operator, this will be 1 element per field, which
1298+
// should be the combiner for that element.
1299+
// -For an array of any of the above, it will be the above for the element.
1300+
llvm::SmallVector<CombinerRecipe, 1> CombinerRecipes;
1301+
1302+
OpenACCReductionRecipe(VarDecl *A, llvm::ArrayRef<CombinerRecipe> Combiners)
1303+
: AllocaDecl(A), CombinerRecipes(Combiners) {}
12861304

12871305
bool isSet() const { return AllocaDecl; }
12881306
static OpenACCReductionRecipe Empty() {
1289-
return OpenACCReductionRecipe(/*AllocaDecl=*/nullptr);
1307+
return OpenACCReductionRecipe(/*AllocaDecl=*/nullptr, {});
12901308
}
12911309
};
12921310

clang/lib/AST/StmtProfile.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2769,10 +2769,19 @@ void OpenACCClauseProfiler::VisitReductionClause(
27692769

27702770
for (auto &Recipe : Clause.getRecipes()) {
27712771
Profiler.VisitDecl(Recipe.AllocaDecl);
2772+
27722773
// TODO: OpenACC: Make sure we remember to update this when we figure out
27732774
// what we're adding for the operation recipe, in the meantime, a static
27742775
// assert will make sure we don't add something.
2775-
static_assert(sizeof(OpenACCReductionRecipe) == sizeof(int *));
2776+
static_assert(sizeof(OpenACCReductionRecipe::CombinerRecipe) ==
2777+
3 * sizeof(int *));
2778+
for (auto &CombinerRecipe : Recipe.CombinerRecipes) {
2779+
if (CombinerRecipe.Op) {
2780+
Profiler.VisitDecl(CombinerRecipe.LHS);
2781+
Profiler.VisitDecl(CombinerRecipe.RHS);
2782+
Profiler.VisitStmt(CombinerRecipe.Op);
2783+
}
2784+
}
27762785
}
27772786
}
27782787

clang/lib/Sema/SemaOpenACC.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2946,5 +2946,5 @@ OpenACCReductionRecipe SemaOpenACC::CreateReductionInitRecipe(
29462946
AllocaDecl->setInit(Init.get());
29472947
AllocaDecl->setInitStyle(VarDecl::CallInit);
29482948
}
2949-
return OpenACCReductionRecipe(AllocaDecl);
2949+
return OpenACCReductionRecipe(AllocaDecl, {});
29502950
}

clang/lib/Serialization/ASTReader.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13009,9 +13009,22 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
1300913009
llvm::SmallVector<OpenACCReductionRecipe> RecipeList;
1301013010

1301113011
for (unsigned I = 0; I < VarList.size(); ++I) {
13012-
static_assert(sizeof(OpenACCReductionRecipe) == sizeof(int *));
1301313012
VarDecl *Recipe = readDeclAs<VarDecl>();
13014-
RecipeList.push_back({Recipe});
13013+
13014+
static_assert(sizeof(OpenACCReductionRecipe::CombinerRecipe) ==
13015+
3 * sizeof(int *));
13016+
13017+
llvm::SmallVector<OpenACCReductionRecipe::CombinerRecipe> Combiners;
13018+
unsigned NumCombiners = readInt();
13019+
for (unsigned I = 0; I < NumCombiners; ++I) {
13020+
VarDecl *LHS = readDeclAs<VarDecl>();
13021+
VarDecl *RHS = readDeclAs<VarDecl>();
13022+
Expr *Op = readExpr();
13023+
13024+
Combiners.push_back({LHS, RHS, Op});
13025+
}
13026+
13027+
RecipeList.push_back({Recipe, Combiners});
1301513028
}
1301613029

1301713030
return OpenACCReductionClause::Create(getContext(), BeginLoc, LParenLoc, Op,

clang/lib/Serialization/ASTWriter.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8925,8 +8925,17 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
89258925
writeOpenACCVarList(RC);
89268926

89278927
for (const OpenACCReductionRecipe &R : RC->getRecipes()) {
8928-
static_assert(sizeof(OpenACCReductionRecipe) == 1 * sizeof(int *));
89298928
AddDeclRef(R.AllocaDecl);
8929+
8930+
static_assert(sizeof(OpenACCReductionRecipe::CombinerRecipe) ==
8931+
3 * sizeof(int *));
8932+
writeUInt32(R.CombinerRecipes.size());
8933+
8934+
for (auto &CombinerRecipe : R.CombinerRecipes) {
8935+
AddDeclRef(CombinerRecipe.LHS);
8936+
AddDeclRef(CombinerRecipe.RHS);
8937+
AddStmt(CombinerRecipe.Op);
8938+
}
89308939
}
89318940
return;
89328941
}

0 commit comments

Comments
 (0)