Skip to content

Commit b77f1b2

Browse files
committed
[OpenACC][CIR] Implement 'atomic capture' lowering
The 'atomic capture' variant of the `atomic` construct accepts either a single statement, or a compound statement containing two statements. Each of the statements it accepts meet a form of the previous read/write/update forms, or is a combination of two. The IR node for atomic capture takes two separate other acc.atomics, plus a terminator. This patch implements all of the lowering for these.
1 parent c4be17a commit b77f1b2

File tree

5 files changed

+899
-119
lines changed

5 files changed

+899
-119
lines changed

clang/include/clang/AST/StmtOpenACC.h

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -818,14 +818,57 @@ class OpenACCAtomicConstruct final
818818

819819
// A struct to represent a broken-down version of the associated statement,
820820
// providing the information specified in OpenACC3.3 Section 2.12.
821-
struct StmtInfo {
821+
struct SingleStmtInfo {
822+
// Holds the entire expression for this. In the case of a normal
823+
// read/write/update, this should just be the associated statement. in the
824+
// case of an update, this is going to be the sub-expression this
825+
// represents.
826+
const Expr *WholeExpr;
822827
const Expr *V;
823828
const Expr *X;
824829
// Listed as 'expr' in the standard, this is typically a generic expression
825830
// as a component.
826831
const Expr *RefExpr;
827-
// TODO: OpenACC: We should expand this as we're implementing the other
828-
// atomic construct kinds.
832+
static SingleStmtInfo Empty() {
833+
return {nullptr, nullptr, nullptr, nullptr};
834+
}
835+
836+
static SingleStmtInfo createRead(const Expr *WholeExpr, const Expr *V,
837+
const Expr *X) {
838+
return {WholeExpr, V, X, /*RefExpr=*/nullptr};
839+
}
840+
static SingleStmtInfo createWrite(const Expr *WholeExpr, const Expr *X,
841+
const Expr *RefExpr) {
842+
return {WholeExpr, /*V=*/nullptr, X, RefExpr};
843+
}
844+
static SingleStmtInfo createUpdate(const Expr *WholeExpr, const Expr *X) {
845+
return {WholeExpr, /*V=*/nullptr, X, /*RefExpr=*/nullptr};
846+
}
847+
};
848+
849+
struct StmtInfo {
850+
enum class StmtForm {
851+
Read,
852+
Write,
853+
Update,
854+
ReadWrite,
855+
ReadUpdate,
856+
UpdateRead
857+
} Form;
858+
SingleStmtInfo First, Second;
859+
860+
static StmtInfo createUpdateRead(SingleStmtInfo First,
861+
SingleStmtInfo Second) {
862+
return {StmtForm::UpdateRead, First, Second};
863+
}
864+
static StmtInfo createReadWrite(SingleStmtInfo First,
865+
SingleStmtInfo Second) {
866+
return {StmtForm::ReadWrite, First, Second};
867+
}
868+
static StmtInfo createReadUpdate(SingleStmtInfo First,
869+
SingleStmtInfo Second) {
870+
return {StmtForm::ReadUpdate, First, Second};
871+
}
829872
};
830873

831874
const StmtInfo getAssociatedStmtInfo() const;

clang/lib/AST/StmtOpenACC.cpp

Lines changed: 207 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -324,30 +324,207 @@ OpenACCAtomicConstruct *OpenACCAtomicConstruct::Create(
324324
return Inst;
325325
}
326326

327-
static std::pair<const Expr *, const Expr *> getBinaryOpArgs(const Expr *Op) {
327+
static std::optional<std::pair<const Expr *, const Expr *>>
328+
getBinaryAssignOpArgs(const Expr *Op, bool &isCompoundAssign) {
328329
if (const auto *BO = dyn_cast<BinaryOperator>(Op)) {
329-
assert(BO->isAssignmentOp());
330-
return {BO->getLHS(), BO->getRHS()};
330+
if (!BO->isAssignmentOp())
331+
return std::nullopt;
332+
isCompoundAssign = BO->isCompoundAssignmentOp();
333+
return std::pair<const Expr *, const Expr *>({BO->getLHS(), BO->getRHS()});
331334
}
332335

333-
const auto *OO = cast<CXXOperatorCallExpr>(Op);
334-
assert(OO->isAssignmentOp());
335-
return {OO->getArg(0), OO->getArg(1)};
336+
if (const auto *OO = dyn_cast<CXXOperatorCallExpr>(Op)) {
337+
if (!OO->isAssignmentOp())
338+
return std::nullopt;
339+
isCompoundAssign = OO->getOperator() != OO_Equal;
340+
return std::pair<const Expr *, const Expr *>(
341+
{OO->getArg(0), OO->getArg(1)});
342+
}
343+
return std::nullopt;
344+
}
345+
static std::optional<std::pair<const Expr *, const Expr *>>
346+
getBinaryAssignOpArgs(const Expr *Op) {
347+
bool isCompoundAssign;
348+
return getBinaryAssignOpArgs(Op, isCompoundAssign);
336349
}
337350

338-
static std::pair<bool, const Expr *> getUnaryOpArgs(const Expr *Op) {
351+
static std::optional<const Expr *> getUnaryOpArgs(const Expr *Op) {
339352
if (const auto *UO = dyn_cast<UnaryOperator>(Op))
340-
return {true, UO->getSubExpr()};
353+
return UO->getSubExpr();
341354

342355
if (const auto *OpCall = dyn_cast<CXXOperatorCallExpr>(Op)) {
343356
// Post-inc/dec have a second unused argument to differentiate it, so we
344357
// accept -- or ++ as unary, or any operator call with only 1 arg.
345358
if (OpCall->getNumArgs() == 1 || OpCall->getOperator() != OO_PlusPlus ||
346359
OpCall->getOperator() != OO_MinusMinus)
347-
return {true, OpCall->getArg(0)};
360+
return {OpCall->getArg(0)};
348361
}
349362

350-
return {false, nullptr};
363+
return std::nullopt;
364+
}
365+
366+
// Read is of the form `v = x;`, where both sides are scalar L-values. This is a
367+
// BinaryOperator or CXXOperatorCallExpr.
368+
static std::optional<OpenACCAtomicConstruct::SingleStmtInfo>
369+
getReadStmtInfo(const Expr *E, bool ForAtomicComputeSingleStmt = false) {
370+
std::optional<std::pair<const Expr *, const Expr *>> BinaryArgs =
371+
getBinaryAssignOpArgs(E);
372+
373+
if (!BinaryArgs)
374+
return std::nullopt;
375+
376+
// We want the L-value for each side, so we ignore implicit casts.
377+
auto Res = OpenACCAtomicConstruct::SingleStmtInfo::createRead(
378+
E, BinaryArgs->first->IgnoreImpCasts(),
379+
BinaryArgs->second->IgnoreImpCasts());
380+
381+
// The atomic compute single-stmt variant has to do a 'fixup' step for the 'X'
382+
// value, since it is dependent on the RHS. So if we're in that version, we
383+
// skip the checks on X.
384+
if ((!ForAtomicComputeSingleStmt &&
385+
(!Res.X->isLValue() || !Res.X->getType()->isScalarType())) ||
386+
!Res.V->isLValue() || !Res.V->getType()->isScalarType())
387+
return std::nullopt;
388+
389+
return Res;
390+
}
391+
392+
// Write supports only the format 'x = expr', where the expression is scalar
393+
// type, and 'x' is a scalar l value. As above, this can come in 2 forms;
394+
// Binary Operator or CXXOperatorCallExpr.
395+
static std::optional<OpenACCAtomicConstruct::SingleStmtInfo>
396+
getWriteStmtInfo(const Expr *E) {
397+
std::optional<std::pair<const Expr *, const Expr *>> BinaryArgs =
398+
getBinaryAssignOpArgs(E);
399+
if (!BinaryArgs)
400+
return std::nullopt;
401+
// We want the L-value for ONLY the X side, so we ignore implicit casts. For
402+
// the right side (the expr), we emit it as an r-value so we need to
403+
// maintain implicit casts.
404+
auto Res = OpenACCAtomicConstruct::SingleStmtInfo::createWrite(
405+
E, BinaryArgs->first->IgnoreImpCasts(), BinaryArgs->second);
406+
407+
if (!Res.X->isLValue() || !Res.X->getType()->isScalarType())
408+
return std::nullopt;
409+
return Res;
410+
}
411+
412+
static std::optional<OpenACCAtomicConstruct::SingleStmtInfo>
413+
getUpdateStmtInfo(const Expr *E) {
414+
std::optional<const Expr *> UnaryArgs = getUnaryOpArgs(E);
415+
if (UnaryArgs) {
416+
auto Res = OpenACCAtomicConstruct::SingleStmtInfo::createUpdate(
417+
E, (*UnaryArgs)->IgnoreImpCasts());
418+
419+
if (!Res.X->isLValue() || !Res.X->getType()->isScalarType())
420+
return std::nullopt;
421+
422+
return Res;
423+
}
424+
425+
bool isRHSCompoundAssign = false;
426+
std::optional<std::pair<const Expr *, const Expr *>> BinaryArgs =
427+
getBinaryAssignOpArgs(E, isRHSCompoundAssign);
428+
if (!BinaryArgs)
429+
return std::nullopt;
430+
431+
auto Res = OpenACCAtomicConstruct::SingleStmtInfo::createUpdate(
432+
E, BinaryArgs->first->IgnoreImpCasts());
433+
434+
if (!Res.X->isLValue() || !Res.X->getType()->isScalarType())
435+
return std::nullopt;
436+
437+
// 'update' has to be either a compound-assignment operation, or
438+
// assignment-to-a-binary-op. Return nullopt if these are not the case.
439+
// If we are already compound-assign, we're done!
440+
if (isRHSCompoundAssign)
441+
return Res;
442+
443+
// else we have to check that we have a binary operator.
444+
const Expr *RHS = BinaryArgs->second->IgnoreImpCasts();
445+
446+
if (isa<BinaryOperator>(RHS))
447+
return Res;
448+
else if (const auto *OO = dyn_cast<CXXOperatorCallExpr>(RHS)) {
449+
if (OO->isInfixBinaryOp())
450+
return Res;
451+
}
452+
453+
return std::nullopt;
454+
}
455+
456+
static OpenACCAtomicConstruct::StmtInfo
457+
getCaptureStmtInfo(const Stmt *AssocStmt) {
458+
if (const auto *CmpdStmt = dyn_cast<CompoundStmt>(AssocStmt)) {
459+
// We checked during Sema to ensure we only have 2 statements here, and
460+
// that both are expressions, we can look at these to see what the valid
461+
// options are.
462+
const Expr *Stmt1 = cast<Expr>(*CmpdStmt->body().begin())->IgnoreImpCasts();
463+
const Expr *Stmt2 =
464+
cast<Expr>(*(CmpdStmt->body().begin() + 1))->IgnoreImpCasts();
465+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Read =
466+
getReadStmtInfo(Stmt1);
467+
468+
if (Read) {
469+
// READ : WRITE
470+
// v = x; x = expr
471+
// READ : UPDATE
472+
// v = x; x binop = expr
473+
// v = x; x = x binop expr
474+
// v = x; x = expr binop x
475+
// v = x; x++
476+
// v = x; ++x
477+
// v = x; x--
478+
// v = x; --x
479+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Update =
480+
getUpdateStmtInfo(Stmt2);
481+
if (Update)
482+
return OpenACCAtomicConstruct::StmtInfo::createReadUpdate(*Read,
483+
*Update);
484+
485+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Write =
486+
getWriteStmtInfo(Stmt2);
487+
return OpenACCAtomicConstruct::StmtInfo::createReadWrite(*Read, *Write);
488+
}
489+
// UPDATE: READ
490+
// x binop = expr; v = x
491+
// x = x binop expr; v = x
492+
// x = expr binop x ; v = x
493+
// ++ x; v = x
494+
// x++; v = x
495+
// --x; v = x
496+
// x--; v = x
497+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Update =
498+
getUpdateStmtInfo(Stmt1);
499+
Read = getReadStmtInfo(Stmt2);
500+
501+
return OpenACCAtomicConstruct::StmtInfo::createUpdateRead(*Update, *Read);
502+
} else {
503+
// All of the possible forms (listed below) that are writable as a single
504+
// line are expressed as an update, then as a read. We should be able to
505+
// just run these two in the right order.
506+
// UPDATE: READ
507+
// v = x++;
508+
// v = x--;
509+
// v = ++x;
510+
// v = --x;
511+
// v = x binop=expr
512+
// v = x = x binop expr
513+
// v = x = expr binop x
514+
515+
const Expr *E = cast<const Expr>(AssocStmt);
516+
517+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Read =
518+
getReadStmtInfo(E, /*ForAtomicComputeSingleStmt=*/true);
519+
std::optional<OpenACCAtomicConstruct::SingleStmtInfo> Update =
520+
getUpdateStmtInfo(Read->X);
521+
522+
// Fixup this, since the 'X' for the read is the result after write, but is
523+
// the same value as the LHS-most variable of the update(its X).
524+
Read->X = Update->X;
525+
return OpenACCAtomicConstruct::StmtInfo::createUpdateRead(*Update, *Read);
526+
}
527+
return {};
351528
}
352529

353530
const OpenACCAtomicConstruct::StmtInfo
@@ -357,48 +534,28 @@ OpenACCAtomicConstruct::getAssociatedStmtInfo() const {
357534
// asserts to ensure we don't get off into the weeds.
358535
assert(getAssociatedStmt() && "invalid associated stmt?");
359536

360-
const Expr *AssocStmt = cast<const Expr>(getAssociatedStmt());
361537
switch (AtomicKind) {
362-
case OpenACCAtomicKind::Capture:
363-
assert(false && "Only 'read'/'write'/'update' have been implemented here");
364-
return {};
365-
case OpenACCAtomicKind::Read: {
366-
// Read only supports the format 'v = x'; where both sides are a scalar
367-
// expression. This can come in 2 forms; BinaryOperator or
368-
// CXXOperatorCallExpr (rarely).
369-
std::pair<const Expr *, const Expr *> BinaryArgs =
370-
getBinaryOpArgs(AssocStmt);
371-
// We want the L-value for each side, so we ignore implicit casts.
372-
return {BinaryArgs.first->IgnoreImpCasts(),
373-
BinaryArgs.second->IgnoreImpCasts(), /*expr=*/nullptr};
374-
}
375-
case OpenACCAtomicKind::Write: {
376-
// Write supports only the format 'x = expr', where the expression is scalar
377-
// type, and 'x' is a scalar l value. As above, this can come in 2 forms;
378-
// Binary Operator or CXXOperatorCallExpr.
379-
std::pair<const Expr *, const Expr *> BinaryArgs =
380-
getBinaryOpArgs(AssocStmt);
381-
// We want the L-value for ONLY the X side, so we ignore implicit casts. For
382-
// the right side (the expr), we emit it as an r-value so we need to
383-
// maintain implicit casts.
384-
return {/*v=*/nullptr, BinaryArgs.first->IgnoreImpCasts(),
385-
BinaryArgs.second};
386-
}
538+
case OpenACCAtomicKind::Read:
539+
return OpenACCAtomicConstruct::StmtInfo{
540+
OpenACCAtomicConstruct::StmtInfo::StmtForm::Read,
541+
*getReadStmtInfo(cast<const Expr>(getAssociatedStmt())),
542+
OpenACCAtomicConstruct::SingleStmtInfo::Empty()};
543+
544+
case OpenACCAtomicKind::Write:
545+
return OpenACCAtomicConstruct::StmtInfo{
546+
OpenACCAtomicConstruct::StmtInfo::StmtForm::Write,
547+
*getWriteStmtInfo(cast<const Expr>(getAssociatedStmt())),
548+
OpenACCAtomicConstruct::SingleStmtInfo::Empty()};
549+
387550
case OpenACCAtomicKind::None:
388-
case OpenACCAtomicKind::Update: {
389-
std::pair<bool, const Expr *> UnaryArgs = getUnaryOpArgs(AssocStmt);
390-
if (UnaryArgs.first)
391-
return {/*v=*/nullptr, UnaryArgs.second->IgnoreImpCasts(),
392-
/*expr=*/nullptr};
393-
394-
std::pair<const Expr *, const Expr *> BinaryArgs =
395-
getBinaryOpArgs(AssocStmt);
396-
// For binary args, we just store the RHS as an expression (in the
397-
// expression slot), since the codegen just wants the whole thing for a
398-
// recipe.
399-
return {/*v=*/nullptr, BinaryArgs.first->IgnoreImpCasts(),
400-
BinaryArgs.second};
401-
}
551+
case OpenACCAtomicKind::Update:
552+
return OpenACCAtomicConstruct::StmtInfo{
553+
OpenACCAtomicConstruct::StmtInfo::StmtForm::Update,
554+
*getUpdateStmtInfo(cast<const Expr>(getAssociatedStmt())),
555+
OpenACCAtomicConstruct::SingleStmtInfo::Empty()};
556+
557+
case OpenACCAtomicKind::Capture:
558+
return getCaptureStmtInfo(getAssociatedStmt());
402559
}
403560

404561
llvm_unreachable("unknown OpenACC atomic kind");

0 commit comments

Comments
 (0)