Skip to content

Commit 1fca50d

Browse files
authored
[OpenACC] Partial Reduction recipe Lowering (#155635)
This patch implements basic reduction recipe lowering, plus adds a bunch of tests for it that should be meaningful later. At the moment, all this does is ensure that we get the init 'alloca' set right (the actual initializer isn't done correctly yet, and will be in a followup), an empty combiner (though the type of certain operations probably has to be different as well, when we get to those), and a full-destruction, as we already have the infrastructure for it.
1 parent 41fed2d commit 1fca50d

20 files changed

+6701
-17
lines changed

clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp

Lines changed: 139 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,8 @@ class OpenACCClauseCIREmitter final
357357
}
358358

359359
template <typename RecipeTy>
360-
std::string getRecipeName(SourceRange loc, QualType baseType) {
360+
std::string getRecipeName(SourceRange loc, QualType baseType,
361+
OpenACCReductionOperator reductionOp) {
361362
std::string recipeName;
362363
{
363364
llvm::raw_string_ostream stream(recipeName);
@@ -371,12 +372,40 @@ class OpenACCClauseCIREmitter final
371372
} else if constexpr (std::is_same_v<RecipeTy,
372373
mlir::acc::ReductionRecipeOp>) {
373374
stream << "reduction_";
374-
// TODO: OpenACC: once we have this part implemented, we can remove the
375-
// SourceRange `loc` variable from this function. We don't have the
376-
// reduction operation here well enough to know how to spell this
377-
// correctly (+ == 'add', etc), so when we implement 'reduction' we have
378-
// to do that here.
379-
cgf.cgm.errorNYI(loc, "OpenACC reduction recipe name creation");
375+
// Values here are a little weird (for bitwise and/or is 'i' prefix, and
376+
// logical ops with 'l'), but are chosen to be the same as the MLIR
377+
// dialect names as well as to match the Flang versions of these.
378+
switch (reductionOp) {
379+
case OpenACCReductionOperator::Addition:
380+
stream << "add_";
381+
break;
382+
case OpenACCReductionOperator::Multiplication:
383+
stream << "mul_";
384+
break;
385+
case OpenACCReductionOperator::Max:
386+
stream << "max_";
387+
break;
388+
case OpenACCReductionOperator::Min:
389+
stream << "min_";
390+
break;
391+
case OpenACCReductionOperator::BitwiseAnd:
392+
stream << "iand_";
393+
break;
394+
case OpenACCReductionOperator::BitwiseOr:
395+
stream << "ior_";
396+
break;
397+
case OpenACCReductionOperator::BitwiseXOr:
398+
stream << "xor_";
399+
break;
400+
case OpenACCReductionOperator::And:
401+
stream << "land_";
402+
break;
403+
case OpenACCReductionOperator::Or:
404+
stream << "lor_";
405+
break;
406+
case OpenACCReductionOperator::Invalid:
407+
llvm_unreachable("invalid reduction operator");
408+
}
380409
} else {
381410
static_assert(!sizeof(RecipeTy), "Unknown Recipe op kind");
382411
}
@@ -419,7 +448,9 @@ class OpenACCClauseCIREmitter final
419448
}
420449

421450
// Create the 'init' section of the recipe, including the 'copy' section for
422-
// 'firstprivate'.
451+
// 'firstprivate'. Note that this function is not 'insertion point' clean, in
452+
// that it alters the insertion point to be inside of the 'destroy' section of
453+
// the recipe, but doesn't restore it aftewards.
423454
template <typename RecipeTy>
424455
void createRecipeInitCopy(mlir::Location loc, mlir::Location locEnd,
425456
SourceRange exprRange, mlir::Value mainOp,
@@ -485,6 +516,27 @@ class OpenACCClauseCIREmitter final
485516
}
486517
}
487518

519+
// This function generates the 'combiner' section for a reduction recipe. Note
520+
// that this function is not 'insertion point' clean, in that it alters the
521+
// insertion point to be inside of the 'combiner' section of the recipe, but
522+
// doesn't restore it aftewards.
523+
void createReductionRecipeCombiner(mlir::Location loc, mlir::Location locEnd,
524+
mlir::Value mainOp,
525+
mlir::acc::ReductionRecipeOp recipe) {
526+
mlir::Block *block = builder.createBlock(
527+
&recipe.getCombinerRegion(), recipe.getCombinerRegion().end(),
528+
{mainOp.getType(), mainOp.getType()}, {loc, loc});
529+
builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
530+
531+
mlir::BlockArgument lhsArg = block->getArgument(0);
532+
533+
mlir::acc::YieldOp::create(builder, locEnd, lhsArg);
534+
}
535+
536+
// This function generates the 'destroy' section for a recipe. Note
537+
// that this function is not 'insertion point' clean, in that it alters the
538+
// insertion point to be inside of the 'destroy' section of the recipe, but
539+
// doesn't restore it aftewards.
488540
void createRecipeDestroySection(mlir::Location loc, mlir::Location locEnd,
489541
mlir::Value mainOp, CharUnits alignment,
490542
QualType baseType,
@@ -502,30 +554,68 @@ class OpenACCClauseCIREmitter final
502554
mlir::acc::YieldOp::create(builder, locEnd);
503555
}
504556

557+
mlir::acc::ReductionOperator convertReductionOp(OpenACCReductionOperator op) {
558+
switch (op) {
559+
case OpenACCReductionOperator::Addition:
560+
return mlir::acc::ReductionOperator::AccAdd;
561+
case OpenACCReductionOperator::Multiplication:
562+
return mlir::acc::ReductionOperator::AccMul;
563+
case OpenACCReductionOperator::Max:
564+
return mlir::acc::ReductionOperator::AccMax;
565+
case OpenACCReductionOperator::Min:
566+
return mlir::acc::ReductionOperator::AccMin;
567+
case OpenACCReductionOperator::BitwiseAnd:
568+
return mlir::acc::ReductionOperator::AccIand;
569+
case OpenACCReductionOperator::BitwiseOr:
570+
return mlir::acc::ReductionOperator::AccIor;
571+
case OpenACCReductionOperator::BitwiseXOr:
572+
return mlir::acc::ReductionOperator::AccXor;
573+
case OpenACCReductionOperator::And:
574+
return mlir::acc::ReductionOperator::AccLand;
575+
case OpenACCReductionOperator::Or:
576+
return mlir::acc::ReductionOperator::AccLor;
577+
case OpenACCReductionOperator::Invalid:
578+
llvm_unreachable("invalid reduction operator");
579+
}
580+
581+
llvm_unreachable("invalid reduction operator");
582+
}
583+
505584
template <typename RecipeTy>
506585
RecipeTy getOrCreateRecipe(ASTContext &astCtx, const Expr *varRef,
507586
const VarDecl *varRecipe, const VarDecl *temporary,
587+
OpenACCReductionOperator reductionOp,
508588
DeclContext *dc, QualType baseType,
509589
mlir::Value mainOp) {
510590
mlir::ModuleOp mod = builder.getBlock()
511591
->getParent()
512592
->template getParentOfType<mlir::ModuleOp>();
513593

514-
std::string recipeName =
515-
getRecipeName<RecipeTy>(varRef->getSourceRange(), baseType);
594+
std::string recipeName = getRecipeName<RecipeTy>(varRef->getSourceRange(),
595+
baseType, reductionOp);
516596
if (auto recipe = mod.lookupSymbol<RecipeTy>(recipeName))
517597
return recipe;
518598

519599
mlir::Location loc = cgf.cgm.getLoc(varRef->getBeginLoc());
520600
mlir::Location locEnd = cgf.cgm.getLoc(varRef->getEndLoc());
521601

522602
mlir::OpBuilder modBuilder(mod.getBodyRegion());
523-
auto recipe =
524-
RecipeTy::create(modBuilder, loc, recipeName, mainOp.getType());
603+
RecipeTy recipe;
604+
605+
if constexpr (std::is_same_v<RecipeTy, mlir::acc::ReductionRecipeOp>) {
606+
recipe = RecipeTy::create(modBuilder, loc, recipeName, mainOp.getType(),
607+
convertReductionOp(reductionOp));
608+
} else {
609+
recipe = RecipeTy::create(modBuilder, loc, recipeName, mainOp.getType());
610+
}
525611

526612
createRecipeInitCopy(loc, locEnd, varRef->getSourceRange(), mainOp, recipe,
527613
varRecipe, temporary);
528614

615+
if constexpr (std::is_same_v<RecipeTy, mlir::acc::ReductionRecipeOp>) {
616+
createReductionRecipeCombiner(loc, locEnd, mainOp, recipe);
617+
}
618+
529619
if (varRecipe && varRecipe->needsDestruction(cgf.getContext()))
530620
createRecipeDestroySection(loc, locEnd, mainOp,
531621
cgf.getContext().getDeclAlign(varRecipe),
@@ -1166,6 +1256,8 @@ class OpenACCClauseCIREmitter final
11661256
mlir::OpBuilder::InsertionGuard guardCase(builder);
11671257
auto recipe = getOrCreateRecipe<mlir::acc::PrivateRecipeOp>(
11681258
cgf.getContext(), varExpr, varRecipe, /*temporary=*/nullptr,
1259+
OpenACCReductionOperator::Invalid,
1260+
11691261
Decl::castToDeclContext(cgf.curFuncDecl), opInfo.baseType,
11701262
privateOp.getResult());
11711263
// TODO: OpenACC: The dialect is going to change in the near future to
@@ -1200,7 +1292,7 @@ class OpenACCClauseCIREmitter final
12001292
mlir::OpBuilder::InsertionGuard guardCase(builder);
12011293
auto recipe = getOrCreateRecipe<mlir::acc::FirstprivateRecipeOp>(
12021294
cgf.getContext(), varExpr, varRecipe.RecipeDecl,
1203-
varRecipe.InitFromTemporary,
1295+
varRecipe.InitFromTemporary, OpenACCReductionOperator::Invalid,
12041296
Decl::castToDeclContext(cgf.curFuncDecl), opInfo.baseType,
12051297
firstPrivateOp.getResult());
12061298

@@ -1219,6 +1311,40 @@ class OpenACCClauseCIREmitter final
12191311
llvm_unreachable("Unknown construct kind in VisitFirstPrivateClause");
12201312
}
12211313
}
1314+
1315+
void VisitReductionClause(const OpenACCReductionClause &clause) {
1316+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
1317+
mlir::acc::LoopOp>) {
1318+
for (const auto [varExpr, varRecipe] :
1319+
llvm::zip_equal(clause.getVarList(), clause.getRecipes())) {
1320+
CIRGenFunction::OpenACCDataOperandInfo opInfo =
1321+
cgf.getOpenACCDataOperandInfo(varExpr);
1322+
1323+
auto reductionOp = mlir::acc::ReductionOp::create(
1324+
builder, opInfo.beginLoc, opInfo.varValue, /*structured=*/true,
1325+
/*implicit=*/false, opInfo.name, opInfo.bounds);
1326+
reductionOp.setDataClause(mlir::acc::DataClause::acc_reduction);
1327+
1328+
{
1329+
mlir::OpBuilder::InsertionGuard guardCase(builder);
1330+
1331+
auto recipe = getOrCreateRecipe<mlir::acc::ReductionRecipeOp>(
1332+
cgf.getContext(), varExpr, varRecipe.RecipeDecl,
1333+
/*temporary=*/nullptr, clause.getReductionOp(),
1334+
Decl::castToDeclContext(cgf.curFuncDecl), opInfo.baseType,
1335+
reductionOp.getResult());
1336+
1337+
operation.addReduction(builder.getContext(), reductionOp, recipe);
1338+
}
1339+
}
1340+
} else if constexpr (isCombinedType<OpTy>) {
1341+
// Despite this being valid on ParallelOp or SerialOp, combined type
1342+
// applies to the 'loop'.
1343+
applyToLoopOp(clause);
1344+
} else {
1345+
llvm_unreachable("Unknown construct kind in VisitReductionClause");
1346+
}
1347+
}
12221348
};
12231349

12241350
template <typename OpTy>

0 commit comments

Comments
 (0)