Skip to content

Commit 89d8d98

Browse files
erichkeaneLukacma
authored andcommitted
[OpenACC][CIR] Reduction combiner lowering for min/max (llvm#163656)
These two are lowered as if they are the expression: LHS = (LHS < RHS ) ? RHS : LHS; and LHS = (LHS < RHS ) ? LHS : RHS; This patch generates these expressions and ensures they are properly emitted into IR. Note: this is dependent on llvm#163580 and cannot be merged until that one is (or the tests will fail).
1 parent b124f68 commit 89d8d98

25 files changed

+4504
-137
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13676,6 +13676,9 @@ def err_acc_reduction_recipe_no_op
1367613676
"not have a valid operation available">;
1367713677
def note_acc_reduction_recipe_noop_field
1367813678
: Note<"while forming combiner for compound type %0">;
13679+
def note_acc_reduction_combiner_forming
13680+
: Note<"while forming %select{|binary operator '%1'|conditional "
13681+
"operator|final assignment operator}0">;
1367913682

1368013683
// AMDGCN builtins diagnostics
1368113684
def err_amdgcn_load_lds_size_invalid_value : Error<"invalid size value">;

clang/lib/CIR/CodeGen/CIRGenOpenACCRecipe.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,15 +590,18 @@ void OpenACCRecipeBuilderBase::createReductionRecipeCombiner(
590590
} else {
591591
// else we have to handle each individual field after after a
592592
// get-element.
593+
const CIRGenRecordLayout &layout =
594+
cgf.cgm.getTypes().getCIRGenRecordLayout(rd);
593595
for (const auto &[field, combiner] :
594596
llvm::zip_equal(rd->fields(), combinerRecipes)) {
595597
mlir::Type fieldType = cgf.convertType(field->getType());
596598
auto fieldPtr = cir::PointerType::get(fieldType);
599+
unsigned fieldIndex = layout.getCIRFieldNo(field);
597600

598601
mlir::Value lhsField = builder.createGetMember(
599-
loc, fieldPtr, lhsArg, field->getName(), field->getFieldIndex());
602+
loc, fieldPtr, lhsArg, field->getName(), fieldIndex);
600603
mlir::Value rhsField = builder.createGetMember(
601-
loc, fieldPtr, rhsArg, field->getName(), field->getFieldIndex());
604+
loc, fieldPtr, rhsArg, field->getName(), fieldIndex);
602605

603606
emitSingleCombiner(lhsField, rhsField, combiner);
604607
}

clang/lib/Sema/SemaOpenACC.cpp

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2996,6 +2996,8 @@ bool SemaOpenACC::CreateReductionCombinerRecipe(
29962996

29972997
case OpenACCReductionOperator::Max:
29982998
case OpenACCReductionOperator::Min:
2999+
BinOp = BinaryOperatorKind::BO_LT;
3000+
break;
29993001
case OpenACCReductionOperator::And:
30003002
case OpenACCReductionOperator::Or:
30013003
// We just want a 'NYI' error in the backend, so leave an empty combiner
@@ -3011,26 +3013,80 @@ bool SemaOpenACC::CreateReductionCombinerRecipe(
30113013

30123014
assert(!VarTy->isArrayType() && "Only 1 level of array allowed");
30133015

3016+
enum class CombinerFailureKind {
3017+
None = 0,
3018+
BinOp = 1,
3019+
Conditional = 2,
3020+
Assignment = 3,
3021+
};
3022+
3023+
auto genCombiner = [&, this](DeclRefExpr *LHSDRE, DeclRefExpr *RHSDRE)
3024+
-> std::pair<ExprResult, CombinerFailureKind> {
3025+
ExprResult BinOpRes =
3026+
SemaRef.BuildBinOp(SemaRef.getCurScope(), Loc, BinOp, LHSDRE, RHSDRE,
3027+
/*ForFoldExpr=*/false);
3028+
switch (ReductionOperator) {
3029+
case OpenACCReductionOperator::Addition:
3030+
case OpenACCReductionOperator::Multiplication:
3031+
case OpenACCReductionOperator::BitwiseAnd:
3032+
case OpenACCReductionOperator::BitwiseOr:
3033+
case OpenACCReductionOperator::BitwiseXOr:
3034+
// These 5 are simple and are being done as compound operators, so we can
3035+
// immediately quit here.
3036+
return {BinOpRes, BinOpRes.isUsable() ? CombinerFailureKind::None
3037+
: CombinerFailureKind::BinOp};
3038+
case OpenACCReductionOperator::Max:
3039+
case OpenACCReductionOperator::Min: {
3040+
// These are done as:
3041+
// LHS = (LHS < RHS) ? LHS : RHS; and LHS = (LHS < RHS) ? RHS : LHS;
3042+
//
3043+
// The BinOpRes should have been created with the less-than, so we just
3044+
// have to build the conditional and assignment.
3045+
if (!BinOpRes.isUsable())
3046+
return {BinOpRes, CombinerFailureKind::BinOp};
3047+
3048+
// Create the correct conditional operator, swapping the results
3049+
// (true/false value) depending on min/max.
3050+
ExprResult CondRes;
3051+
if (ReductionOperator == OpenACCReductionOperator::Min)
3052+
CondRes = SemaRef.ActOnConditionalOp(Loc, Loc, BinOpRes.get(), LHSDRE,
3053+
RHSDRE);
3054+
else
3055+
CondRes = SemaRef.ActOnConditionalOp(Loc, Loc, BinOpRes.get(), RHSDRE,
3056+
LHSDRE);
3057+
3058+
if (!CondRes.isUsable())
3059+
return {CondRes, CombinerFailureKind::Conditional};
3060+
3061+
// Build assignment.
3062+
ExprResult Assignment = SemaRef.BuildBinOp(SemaRef.getCurScope(), Loc,
3063+
BinaryOperatorKind::BO_Assign,
3064+
LHSDRE, CondRes.get(),
3065+
/*ForFoldExpr=*/false);
3066+
return {Assignment, Assignment.isUsable()
3067+
? CombinerFailureKind::None
3068+
: CombinerFailureKind::Assignment};
3069+
}
3070+
case OpenACCReductionOperator::And:
3071+
case OpenACCReductionOperator::Or:
3072+
llvm_unreachable("And/Or not implemented, but should fail earlier");
3073+
case OpenACCReductionOperator::Invalid:
3074+
llvm_unreachable("Invalid should have been caught above");
3075+
}
3076+
};
3077+
30143078
auto tryCombiner = [&, this](DeclRefExpr *LHSDRE, DeclRefExpr *RHSDRE,
30153079
bool IncludeTrap) {
3016-
// TODO: OpenACC: we have to figure out based on the bin-op how to do the
3017-
// ones that we can't just use compound operators for. So &&, ||, max, and
3018-
// min aren't really clear what we could do here.
30193080
if (IncludeTrap) {
30203081
// Trap all of the errors here, we'll emit our own at the end.
30213082
Sema::TentativeAnalysisScope Trap{SemaRef};
3022-
3023-
return SemaRef.BuildBinOp(SemaRef.getCurScope(), Loc, BinOp, LHSDRE,
3024-
RHSDRE,
3025-
/*ForFoldExpr=*/false);
3026-
} else {
3027-
return SemaRef.BuildBinOp(SemaRef.getCurScope(), Loc, BinOp, LHSDRE,
3028-
RHSDRE,
3029-
/*ForFoldExpr=*/false);
3083+
return genCombiner(LHSDRE, RHSDRE);
30303084
}
3085+
return genCombiner(LHSDRE, RHSDRE);
30313086
};
30323087

30333088
struct CombinerAttemptTy {
3089+
CombinerFailureKind FailKind;
30343090
VarDecl *LHS;
30353091
DeclRefExpr *LHSDRE;
30363092
VarDecl *RHS;
@@ -3058,9 +3114,11 @@ bool SemaOpenACC::CreateReductionCombinerRecipe(
30583114
RHSDecl->getBeginLoc()},
30593115
Ty, clang::VK_LValue, RHSDecl, nullptr, NOUR_None);
30603116

3061-
ExprResult BinOpResult = tryCombiner(LHSDRE, RHSDRE, /*IncludeTrap=*/true);
3117+
std::pair<ExprResult, CombinerFailureKind> BinOpResult =
3118+
tryCombiner(LHSDRE, RHSDRE, /*IncludeTrap=*/true);
30623119

3063-
return {LHSDecl, LHSDRE, RHSDecl, RHSDRE, BinOpResult.get()};
3120+
return {BinOpResult.second, LHSDecl, LHSDRE, RHSDecl, RHSDRE,
3121+
BinOpResult.first.get()};
30643122
};
30653123

30663124
CombinerAttemptTy TopLevelCombinerInfo = formCombiner(VarTy);
@@ -3081,12 +3139,20 @@ bool SemaOpenACC::CreateReductionCombinerRecipe(
30813139
}
30823140
}
30833141

3142+
auto EmitFailureNote = [&](CombinerFailureKind CFK) {
3143+
if (CFK == CombinerFailureKind::BinOp)
3144+
return Diag(Loc, diag::note_acc_reduction_combiner_forming)
3145+
<< CFK << BinaryOperator::getOpcodeStr(BinOp);
3146+
return Diag(Loc, diag::note_acc_reduction_combiner_forming) << CFK;
3147+
};
3148+
30843149
// Since the 'root' level didn't fail, the only thing that could be successful
30853150
// is a struct that we decompose on its individual fields.
30863151

30873152
RecordDecl *RD = VarTy->getAsRecordDecl();
30883153
if (!RD) {
30893154
Diag(Loc, diag::err_acc_reduction_recipe_no_op) << VarTy;
3155+
EmitFailureNote(TopLevelCombinerInfo.FailKind);
30903156
tryCombiner(TopLevelCombinerInfo.LHSDRE, TopLevelCombinerInfo.RHSDRE,
30913157
/*IncludeTrap=*/false);
30923158
return true;
@@ -3098,6 +3164,7 @@ bool SemaOpenACC::CreateReductionCombinerRecipe(
30983164
if (!FieldCombinerInfo.Op || FieldCombinerInfo.Op->containsErrors()) {
30993165
Diag(Loc, diag::err_acc_reduction_recipe_no_op) << FD->getType();
31003166
Diag(FD->getBeginLoc(), diag::note_acc_reduction_recipe_noop_field) << RD;
3167+
EmitFailureNote(FieldCombinerInfo.FailKind);
31013168
tryCombiner(FieldCombinerInfo.LHSDRE, FieldCombinerInfo.RHSDRE,
31023169
/*IncludeTrap=*/false);
31033170
return true;

0 commit comments

Comments
 (0)