diff --git a/flang/include/flang/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h index d1dbaefcd81d0..93ab2e350d035 100644 --- a/flang/include/flang/Lower/DirectivesCommon.h +++ b/flang/include/flang/Lower/DirectivesCommon.h @@ -46,520 +46,6 @@ namespace Fortran { namespace lower { -/// Populates \p hint and \p memoryOrder with appropriate clause information -/// if present on atomic construct. -static inline void genOmpAtomicHintAndMemoryOrderClauses( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpAtomicClauseList &clauseList, - mlir::IntegerAttr &hint, - mlir::omp::ClauseMemoryOrderKindAttr &memoryOrder) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - for (const Fortran::parser::OmpAtomicClause &clause : clauseList.v) { - common::visit( - common::visitors{ - [&](const parser::OmpMemoryOrderClause &s) { - auto kind = common::visit( - common::visitors{ - [&](const parser::OmpClause::AcqRel &) { - return mlir::omp::ClauseMemoryOrderKind::Acq_rel; - }, - [&](const parser::OmpClause::Acquire &) { - return mlir::omp::ClauseMemoryOrderKind::Acquire; - }, - [&](const parser::OmpClause::Relaxed &) { - return mlir::omp::ClauseMemoryOrderKind::Relaxed; - }, - [&](const parser::OmpClause::Release &) { - return mlir::omp::ClauseMemoryOrderKind::Release; - }, - [&](const parser::OmpClause::SeqCst &) { - return mlir::omp::ClauseMemoryOrderKind::Seq_cst; - }, - [&](auto &&) -> mlir::omp::ClauseMemoryOrderKind { - llvm_unreachable("Unexpected clause"); - }, - }, - s.v.u); - memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get( - firOpBuilder.getContext(), kind); - }, - [&](const parser::OmpHintClause &s) { - const auto *expr = Fortran::semantics::GetExpr(s.v); - uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr); - hint = firOpBuilder.getI64IntegerAttr(hintExprValue); - }, - [&](const parser::OmpFailClause &) {}, - }, - clause.u); - } -} - -template -static void processOmpAtomicTODO(mlir::Type elementType, - [[maybe_unused]] mlir::Location loc) { - if (!elementType) - return; - if constexpr (std::is_same()) { - assert(fir::isa_trivial(fir::unwrapRefType(elementType)) && - "is supported type for omp atomic"); - } -} - -/// Used to generate atomic.read operation which is created in existing -/// location set by builder. -template -static inline void genOmpAccAtomicCaptureStatement( - Fortran::lower::AbstractConverter &converter, mlir::Value fromAddress, - mlir::Value toAddress, - [[maybe_unused]] const AtomicListT *leftHandClauseList, - [[maybe_unused]] const AtomicListT *rightHandClauseList, - mlir::Type elementType, mlir::Location loc) { - // Generate `atomic.read` operation for atomic assigment statements - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - processOmpAtomicTODO(elementType, loc); - - if constexpr (std::is_same()) { - // If no hint clause is specified, the effect is as if - // hint(omp_sync_hint_none) had been specified. - mlir::IntegerAttr hint = nullptr; - - mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; - if (leftHandClauseList) - genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, - hint, memoryOrder); - if (rightHandClauseList) - genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, - hint, memoryOrder); - firOpBuilder.create( - loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType), hint, - memoryOrder); - } else { - firOpBuilder.create( - loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType)); - } -} - -/// Used to generate atomic.write operation which is created in existing -/// location set by builder. -template -static inline void genOmpAccAtomicWriteStatement( - Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr, - mlir::Value rhsExpr, [[maybe_unused]] const AtomicListT *leftHandClauseList, - [[maybe_unused]] const AtomicListT *rightHandClauseList, mlir::Location loc, - mlir::Value *evaluatedExprValue = nullptr) { - // Generate `atomic.write` operation for atomic assignment statements - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - mlir::Type varType = fir::unwrapRefType(lhsAddr.getType()); - // Create a conversion outside the capture block. - auto insertionPoint = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPointAfter(rhsExpr.getDefiningOp()); - rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr); - firOpBuilder.restoreInsertionPoint(insertionPoint); - - processOmpAtomicTODO(varType, loc); - - if constexpr (std::is_same()) { - // If no hint clause is specified, the effect is as if - // hint(omp_sync_hint_none) had been specified. - mlir::IntegerAttr hint = nullptr; - mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; - if (leftHandClauseList) - genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, - hint, memoryOrder); - if (rightHandClauseList) - genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, - hint, memoryOrder); - firOpBuilder.create(loc, lhsAddr, rhsExpr, hint, - memoryOrder); - } else { - firOpBuilder.create(loc, lhsAddr, rhsExpr); - } -} - -/// Used to generate atomic.update operation which is created in existing -/// location set by builder. -template -static inline void genOmpAccAtomicUpdateStatement( - Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr, - mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable, - const Fortran::parser::Expr &assignmentStmtExpr, - [[maybe_unused]] const AtomicListT *leftHandClauseList, - [[maybe_unused]] const AtomicListT *rightHandClauseList, mlir::Location loc, - mlir::Operation *atomicCaptureOp = nullptr) { - // Generate `atomic.update` operation for atomic assignment statements - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - - // Create the omp.atomic.update or acc.atomic.update operation - // - // func.func @_QPsb() { - // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"} - // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"} - // %2 = fir.load %1 : !fir.ref - // omp.atomic.update %0 : !fir.ref { - // ^bb0(%arg0: i32): - // %3 = arith.addi %arg0, %2 : i32 - // omp.yield(%3 : i32) - // } - // return - // } - - auto getArgExpression = - [](std::list::const_iterator it) { - const auto &arg{std::get((*it).t)}; - const auto *parserExpr{ - std::get_if>(&arg.u)}; - return parserExpr; - }; - - // Lower any non atomic sub-expression before the atomic operation, and - // map its lowered value to the semantic representation. - Fortran::lower::ExprToValueMap exprValueOverrides; - // Max and min intrinsics can have a list of Args. Hence we need a list - // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted. - llvm::SmallVector nonAtomicSubExprs; - Fortran::common::visit( - Fortran::common::visitors{ - [&](const common::Indirection &funcRef) - -> void { - const auto &args{std::get>( - funcRef.value().v.t)}; - std::list::const_iterator beginIt = - args.begin(); - std::list::const_iterator endIt = args.end(); - const auto *exprFirst{getArgExpression(beginIt)}; - if (exprFirst && exprFirst->value().source == - assignmentStmtVariable.GetSource()) { - // Add everything except the first - beginIt++; - } else { - // Add everything except the last - endIt--; - } - std::list::const_iterator it; - for (it = beginIt; it != endIt; it++) { - const common::Indirection *expr = - getArgExpression(it); - if (expr) - nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr)); - } - }, - [&](const auto &op) -> void { - using T = std::decay_t; - if constexpr (std::is_base_of< - Fortran::parser::Expr::IntrinsicBinary, - T>::value) { - const auto &exprLeft{std::get<0>(op.t)}; - const auto &exprRight{std::get<1>(op.t)}; - if (exprLeft.value().source == assignmentStmtVariable.GetSource()) - nonAtomicSubExprs.push_back( - Fortran::semantics::GetExpr(exprRight)); - else - nonAtomicSubExprs.push_back( - Fortran::semantics::GetExpr(exprLeft)); - } - }, - }, - assignmentStmtExpr.u); - StatementContext nonAtomicStmtCtx; - if (!nonAtomicSubExprs.empty()) { - // Generate non atomic part before all the atomic operations. - auto insertionPoint = firOpBuilder.saveInsertionPoint(); - if (atomicCaptureOp) - firOpBuilder.setInsertionPoint(atomicCaptureOp); - mlir::Value nonAtomicVal; - for (auto *nonAtomicSubExpr : nonAtomicSubExprs) { - nonAtomicVal = fir::getBase(converter.genExprValue( - currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx)); - exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal); - } - if (atomicCaptureOp) - firOpBuilder.restoreInsertionPoint(insertionPoint); - } - - mlir::Operation *atomicUpdateOp = nullptr; - if constexpr (std::is_same()) { - // If no hint clause is specified, the effect is as if - // hint(omp_sync_hint_none) had been specified. - mlir::IntegerAttr hint = nullptr; - mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; - if (leftHandClauseList) - genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, - hint, memoryOrder); - if (rightHandClauseList) - genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, - hint, memoryOrder); - atomicUpdateOp = firOpBuilder.create( - currentLocation, lhsAddr, hint, memoryOrder); - } else { - atomicUpdateOp = firOpBuilder.create( - currentLocation, lhsAddr); - } - - processOmpAtomicTODO(varType, loc); - - llvm::SmallVector varTys = {varType}; - llvm::SmallVector locs = {currentLocation}; - firOpBuilder.createBlock(&atomicUpdateOp->getRegion(0), {}, varTys, locs); - mlir::Value val = - fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0)); - - exprValueOverrides.try_emplace( - Fortran::semantics::GetExpr(assignmentStmtVariable), val); - { - // statement context inside the atomic block. - converter.overrideExprValues(&exprValueOverrides); - Fortran::lower::StatementContext atomicStmtCtx; - mlir::Value rhsExpr = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx)); - mlir::Value convertResult = - firOpBuilder.createConvert(currentLocation, varType, rhsExpr); - if constexpr (std::is_same()) { - firOpBuilder.create(currentLocation, convertResult); - } else { - firOpBuilder.create(currentLocation, convertResult); - } - converter.resetExprOverrides(); - } - firOpBuilder.setInsertionPointAfter(atomicUpdateOp); -} - -/// Processes an atomic construct with write clause. -template -void genOmpAccAtomicWrite(Fortran::lower::AbstractConverter &converter, - const AtomicT &atomicWrite, mlir::Location loc) { - const AtomicListT *rightHandClauseList = nullptr; - const AtomicListT *leftHandClauseList = nullptr; - if constexpr (std::is_same()) { - // Get the address of atomic read operands. - rightHandClauseList = &std::get<2>(atomicWrite.t); - leftHandClauseList = &std::get<0>(atomicWrite.t); - } - - const Fortran::parser::AssignmentStmt &stmt = - std::get>( - atomicWrite.t) - .statement; - const Fortran::evaluate::Assignment &assign = *stmt.typedAssignment->v; - Fortran::lower::StatementContext stmtCtx; - // Get the value and address of atomic write operands. - mlir::Value rhsExpr = - fir::getBase(converter.genExprValue(assign.rhs, stmtCtx)); - mlir::Value lhsAddr = - fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx)); - genOmpAccAtomicWriteStatement(converter, lhsAddr, rhsExpr, leftHandClauseList, - rightHandClauseList, loc); -} - -/// Processes an atomic construct with read clause. -template -void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter, - const AtomicT &atomicRead, mlir::Location loc) { - const AtomicListT *rightHandClauseList = nullptr; - const AtomicListT *leftHandClauseList = nullptr; - if constexpr (std::is_same()) { - // Get the address of atomic read operands. - rightHandClauseList = &std::get<2>(atomicRead.t); - leftHandClauseList = &std::get<0>(atomicRead.t); - } - - const auto &assignmentStmtExpr = std::get( - std::get>( - atomicRead.t) - .statement.t); - const auto &assignmentStmtVariable = std::get( - std::get>( - atomicRead.t) - .statement.t); - - Fortran::lower::StatementContext stmtCtx; - const Fortran::semantics::SomeExpr &fromExpr = - *Fortran::semantics::GetExpr(assignmentStmtExpr); - mlir::Type elementType = converter.genType(fromExpr); - mlir::Value fromAddress = - fir::getBase(converter.genExprAddr(fromExpr, stmtCtx)); - mlir::Value toAddress = fir::getBase(converter.genExprAddr( - *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); - genOmpAccAtomicCaptureStatement(converter, fromAddress, toAddress, - leftHandClauseList, rightHandClauseList, - elementType, loc); -} - -/// Processes an atomic construct with update clause. -template -void genOmpAccAtomicUpdate(Fortran::lower::AbstractConverter &converter, - const AtomicT &atomicUpdate, mlir::Location loc) { - const AtomicListT *rightHandClauseList = nullptr; - const AtomicListT *leftHandClauseList = nullptr; - if constexpr (std::is_same()) { - // Get the address of atomic read operands. - rightHandClauseList = &std::get<2>(atomicUpdate.t); - leftHandClauseList = &std::get<0>(atomicUpdate.t); - } - - const auto &assignmentStmtExpr = std::get( - std::get>( - atomicUpdate.t) - .statement.t); - const auto &assignmentStmtVariable = std::get( - std::get>( - atomicUpdate.t) - .statement.t); - - Fortran::lower::StatementContext stmtCtx; - mlir::Value lhsAddr = fir::getBase(converter.genExprAddr( - *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); - mlir::Type varType = fir::unwrapRefType(lhsAddr.getType()); - genOmpAccAtomicUpdateStatement( - converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr, - leftHandClauseList, rightHandClauseList, loc); -} - -/// Processes an atomic construct with no clause - which implies update clause. -template -void genOmpAtomic(Fortran::lower::AbstractConverter &converter, - const AtomicT &atomicConstruct, mlir::Location loc) { - const AtomicListT &atomicClauseList = - std::get(atomicConstruct.t); - const auto &assignmentStmtExpr = std::get( - std::get>( - atomicConstruct.t) - .statement.t); - const auto &assignmentStmtVariable = std::get( - std::get>( - atomicConstruct.t) - .statement.t); - Fortran::lower::StatementContext stmtCtx; - mlir::Value lhsAddr = fir::getBase(converter.genExprAddr( - *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); - mlir::Type varType = fir::unwrapRefType(lhsAddr.getType()); - // If atomic-clause is not present on the construct, the behaviour is as if - // the update clause is specified (for both OpenMP and OpenACC). - genOmpAccAtomicUpdateStatement( - converter, lhsAddr, varType, assignmentStmtVariable, assignmentStmtExpr, - &atomicClauseList, nullptr, loc); -} - -/// Processes an atomic construct with capture clause. -template -void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter, - const AtomicT &atomicCapture, mlir::Location loc) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - const Fortran::parser::AssignmentStmt &stmt1 = - std::get(atomicCapture.t).v.statement; - const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v; - const auto &stmt1Var{std::get(stmt1.t)}; - const auto &stmt1Expr{std::get(stmt1.t)}; - const Fortran::parser::AssignmentStmt &stmt2 = - std::get(atomicCapture.t).v.statement; - const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v; - const auto &stmt2Var{std::get(stmt2.t)}; - const auto &stmt2Expr{std::get(stmt2.t)}; - - // Pre-evaluate expressions to be used in the various operations inside - // `atomic.capture` since it is not desirable to have anything other than - // a `atomic.read`, `atomic.write`, or `atomic.update` operation - // inside `atomic.capture` - Fortran::lower::StatementContext stmtCtx; - // LHS evaluations are common to all combinations of `atomic.capture` - mlir::Value stmt1LHSArg = - fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx)); - mlir::Value stmt2LHSArg = - fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx)); - - // Type information used in generation of `atomic.update` operation - mlir::Type stmt1VarType = - fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType(); - mlir::Type stmt2VarType = - fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType(); - - mlir::Operation *atomicCaptureOp = nullptr; - if constexpr (std::is_same()) { - mlir::IntegerAttr hint = nullptr; - mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; - const AtomicListT &rightHandClauseList = std::get<2>(atomicCapture.t); - const AtomicListT &leftHandClauseList = std::get<0>(atomicCapture.t); - genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, - memoryOrder); - genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, - memoryOrder); - atomicCaptureOp = - firOpBuilder.create(loc, hint, memoryOrder); - } else { - atomicCaptureOp = firOpBuilder.create(loc); - } - - firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0))); - mlir::Block &block = atomicCaptureOp->getRegion(0).back(); - firOpBuilder.setInsertionPointToStart(&block); - if (Fortran::semantics::checkForSingleVariableOnRHS(stmt1)) { - if (Fortran::semantics::checkForSymbolMatch(stmt2)) { - // Atomic capture construct is of the form [capture-stmt, update-stmt] - const Fortran::semantics::SomeExpr &fromExpr = - *Fortran::semantics::GetExpr(stmt1Expr); - mlir::Type elementType = converter.genType(fromExpr); - genOmpAccAtomicCaptureStatement( - converter, stmt2LHSArg, stmt1LHSArg, - /*leftHandClauseList=*/nullptr, - /*rightHandClauseList=*/nullptr, elementType, loc); - genOmpAccAtomicUpdateStatement( - converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr, - /*leftHandClauseList=*/nullptr, - /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp); - } else { - // Atomic capture construct is of the form [capture-stmt, write-stmt] - firOpBuilder.setInsertionPoint(atomicCaptureOp); - mlir::Value stmt2RHSArg = - fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx)); - firOpBuilder.setInsertionPointToStart(&block); - const Fortran::semantics::SomeExpr &fromExpr = - *Fortran::semantics::GetExpr(stmt1Expr); - mlir::Type elementType = converter.genType(fromExpr); - genOmpAccAtomicCaptureStatement( - converter, stmt2LHSArg, stmt1LHSArg, - /*leftHandClauseList=*/nullptr, - /*rightHandClauseList=*/nullptr, elementType, loc); - genOmpAccAtomicWriteStatement( - converter, stmt2LHSArg, stmt2RHSArg, - /*leftHandClauseList=*/nullptr, - /*rightHandClauseList=*/nullptr, loc); - } - } else { - // Atomic capture construct is of the form [update-stmt, capture-stmt] - const Fortran::semantics::SomeExpr &fromExpr = - *Fortran::semantics::GetExpr(stmt2Expr); - mlir::Type elementType = converter.genType(fromExpr); - genOmpAccAtomicUpdateStatement( - converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr, - /*leftHandClauseList=*/nullptr, - /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp); - genOmpAccAtomicCaptureStatement( - converter, stmt1LHSArg, stmt2LHSArg, - /*leftHandClauseList=*/nullptr, - /*rightHandClauseList=*/nullptr, elementType, loc); - } - firOpBuilder.setInsertionPointToEnd(&block); - if constexpr (std::is_same()) { - firOpBuilder.create(loc); - } else { - firOpBuilder.create(loc); - } - firOpBuilder.setInsertionPointToStart(&block); -} - /// Create empty blocks for the current region. /// These blocks replace blocks parented to an enclosing region. template diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 418bf4ee3d15f..e6175ebda40b2 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -375,6 +375,310 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) { llvm::report_fatal_error("Could not find symbol"); } +/// Used to generate atomic.read operation which is created in existing +/// location set by builder. +static inline void +genAtomicCaptureStatement(Fortran::lower::AbstractConverter &converter, + mlir::Value fromAddress, mlir::Value toAddress, + mlir::Type elementType, mlir::Location loc) { + // Generate `atomic.read` operation for atomic assigment statements + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + firOpBuilder.create( + loc, fromAddress, toAddress, mlir::TypeAttr::get(elementType)); +} + +/// Used to generate atomic.write operation which is created in existing +/// location set by builder. +static inline void +genAtomicWriteStatement(Fortran::lower::AbstractConverter &converter, + mlir::Value lhsAddr, mlir::Value rhsExpr, + mlir::Location loc, + mlir::Value *evaluatedExprValue = nullptr) { + // Generate `atomic.write` operation for atomic assignment statements + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + mlir::Type varType = fir::unwrapRefType(lhsAddr.getType()); + // Create a conversion outside the capture block. + auto insertionPoint = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPointAfter(rhsExpr.getDefiningOp()); + rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr); + firOpBuilder.restoreInsertionPoint(insertionPoint); + + firOpBuilder.create(loc, lhsAddr, rhsExpr); +} + +/// Used to generate atomic.update operation which is created in existing +/// location set by builder. +static inline void genAtomicUpdateStatement( + Fortran::lower::AbstractConverter &converter, mlir::Value lhsAddr, + mlir::Type varType, const Fortran::parser::Variable &assignmentStmtVariable, + const Fortran::parser::Expr &assignmentStmtExpr, mlir::Location loc, + mlir::Operation *atomicCaptureOp = nullptr) { + // Generate `atomic.update` operation for atomic assignment statements + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + + // Create the omp.atomic.update or acc.atomic.update operation + // + // func.func @_QPsb() { + // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"} + // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"} + // %2 = fir.load %1 : !fir.ref + // omp.atomic.update %0 : !fir.ref { + // ^bb0(%arg0: i32): + // %3 = arith.addi %arg0, %2 : i32 + // omp.yield(%3 : i32) + // } + // return + // } + + auto getArgExpression = + [](std::list::const_iterator it) { + const auto &arg{std::get((*it).t)}; + const auto *parserExpr{ + std::get_if>( + &arg.u)}; + return parserExpr; + }; + + // Lower any non atomic sub-expression before the atomic operation, and + // map its lowered value to the semantic representation. + Fortran::lower::ExprToValueMap exprValueOverrides; + // Max and min intrinsics can have a list of Args. Hence we need a list + // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted. + llvm::SmallVector nonAtomicSubExprs; + Fortran::common::visit( + Fortran::common::visitors{ + [&](const Fortran::common::Indirection< + Fortran::parser::FunctionReference> &funcRef) -> void { + const auto &args{ + std::get>( + funcRef.value().v.t)}; + std::list::const_iterator beginIt = + args.begin(); + std::list::const_iterator endIt = + args.end(); + const auto *exprFirst{getArgExpression(beginIt)}; + if (exprFirst && exprFirst->value().source == + assignmentStmtVariable.GetSource()) { + // Add everything except the first + beginIt++; + } else { + // Add everything except the last + endIt--; + } + std::list::const_iterator it; + for (it = beginIt; it != endIt; it++) { + const Fortran::common::Indirection *expr = + getArgExpression(it); + if (expr) + nonAtomicSubExprs.push_back(Fortran::semantics::GetExpr(*expr)); + } + }, + [&](const auto &op) -> void { + using T = std::decay_t; + if constexpr (std::is_base_of< + Fortran::parser::Expr::IntrinsicBinary, + T>::value) { + const auto &exprLeft{std::get<0>(op.t)}; + const auto &exprRight{std::get<1>(op.t)}; + if (exprLeft.value().source == assignmentStmtVariable.GetSource()) + nonAtomicSubExprs.push_back( + Fortran::semantics::GetExpr(exprRight)); + else + nonAtomicSubExprs.push_back( + Fortran::semantics::GetExpr(exprLeft)); + } + }, + }, + assignmentStmtExpr.u); + Fortran::lower::StatementContext nonAtomicStmtCtx; + if (!nonAtomicSubExprs.empty()) { + // Generate non atomic part before all the atomic operations. + auto insertionPoint = firOpBuilder.saveInsertionPoint(); + if (atomicCaptureOp) + firOpBuilder.setInsertionPoint(atomicCaptureOp); + mlir::Value nonAtomicVal; + for (auto *nonAtomicSubExpr : nonAtomicSubExprs) { + nonAtomicVal = fir::getBase(converter.genExprValue( + currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx)); + exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal); + } + if (atomicCaptureOp) + firOpBuilder.restoreInsertionPoint(insertionPoint); + } + + mlir::Operation *atomicUpdateOp = nullptr; + atomicUpdateOp = + firOpBuilder.create(currentLocation, lhsAddr); + + llvm::SmallVector varTys = {varType}; + llvm::SmallVector locs = {currentLocation}; + firOpBuilder.createBlock(&atomicUpdateOp->getRegion(0), {}, varTys, locs); + mlir::Value val = + fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0)); + + exprValueOverrides.try_emplace( + Fortran::semantics::GetExpr(assignmentStmtVariable), val); + { + // statement context inside the atomic block. + converter.overrideExprValues(&exprValueOverrides); + Fortran::lower::StatementContext atomicStmtCtx; + mlir::Value rhsExpr = fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx)); + mlir::Value convertResult = + firOpBuilder.createConvert(currentLocation, varType, rhsExpr); + firOpBuilder.create(currentLocation, convertResult); + converter.resetExprOverrides(); + } + firOpBuilder.setInsertionPointAfter(atomicUpdateOp); +} + +/// Processes an atomic construct with write clause. +void genAtomicWrite(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccAtomicWrite &atomicWrite, + mlir::Location loc) { + const Fortran::parser::AssignmentStmt &stmt = + std::get>( + atomicWrite.t) + .statement; + const Fortran::evaluate::Assignment &assign = *stmt.typedAssignment->v; + Fortran::lower::StatementContext stmtCtx; + // Get the value and address of atomic write operands. + mlir::Value rhsExpr = + fir::getBase(converter.genExprValue(assign.rhs, stmtCtx)); + mlir::Value lhsAddr = + fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx)); + genAtomicWriteStatement(converter, lhsAddr, rhsExpr, loc); +} + +/// Processes an atomic construct with read clause. +void genAtomicRead(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccAtomicRead &atomicRead, + mlir::Location loc) { + const auto &assignmentStmtExpr = std::get( + std::get>( + atomicRead.t) + .statement.t); + const auto &assignmentStmtVariable = std::get( + std::get>( + atomicRead.t) + .statement.t); + + Fortran::lower::StatementContext stmtCtx; + const Fortran::semantics::SomeExpr &fromExpr = + *Fortran::semantics::GetExpr(assignmentStmtExpr); + mlir::Type elementType = converter.genType(fromExpr); + mlir::Value fromAddress = + fir::getBase(converter.genExprAddr(fromExpr, stmtCtx)); + mlir::Value toAddress = fir::getBase(converter.genExprAddr( + *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); + genAtomicCaptureStatement(converter, fromAddress, toAddress, elementType, + loc); +} + +/// Processes an atomic construct with update clause. +void genAtomicUpdate(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccAtomicUpdate &atomicUpdate, + mlir::Location loc) { + const auto &assignmentStmtExpr = std::get( + std::get>( + atomicUpdate.t) + .statement.t); + const auto &assignmentStmtVariable = std::get( + std::get>( + atomicUpdate.t) + .statement.t); + + Fortran::lower::StatementContext stmtCtx; + mlir::Value lhsAddr = fir::getBase(converter.genExprAddr( + *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); + mlir::Type varType = fir::unwrapRefType(lhsAddr.getType()); + genAtomicUpdateStatement(converter, lhsAddr, varType, assignmentStmtVariable, + assignmentStmtExpr, loc); +} + +/// Processes an atomic construct with capture clause. +void genAtomicCapture(Fortran::lower::AbstractConverter &converter, + const Fortran::parser::AccAtomicCapture &atomicCapture, + mlir::Location loc) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + const Fortran::parser::AssignmentStmt &stmt1 = + std::get(atomicCapture.t) + .v.statement; + const Fortran::evaluate::Assignment &assign1 = *stmt1.typedAssignment->v; + const auto &stmt1Var{std::get(stmt1.t)}; + const auto &stmt1Expr{std::get(stmt1.t)}; + const Fortran::parser::AssignmentStmt &stmt2 = + std::get(atomicCapture.t) + .v.statement; + const Fortran::evaluate::Assignment &assign2 = *stmt2.typedAssignment->v; + const auto &stmt2Var{std::get(stmt2.t)}; + const auto &stmt2Expr{std::get(stmt2.t)}; + + // Pre-evaluate expressions to be used in the various operations inside + // `atomic.capture` since it is not desirable to have anything other than + // a `atomic.read`, `atomic.write`, or `atomic.update` operation + // inside `atomic.capture` + Fortran::lower::StatementContext stmtCtx; + // LHS evaluations are common to all combinations of `atomic.capture` + mlir::Value stmt1LHSArg = + fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx)); + mlir::Value stmt2LHSArg = + fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx)); + + // Type information used in generation of `atomic.update` operation + mlir::Type stmt1VarType = + fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType(); + mlir::Type stmt2VarType = + fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType(); + + mlir::Operation *atomicCaptureOp = nullptr; + atomicCaptureOp = firOpBuilder.create(loc); + + firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0))); + mlir::Block &block = atomicCaptureOp->getRegion(0).back(); + firOpBuilder.setInsertionPointToStart(&block); + if (Fortran::semantics::checkForSingleVariableOnRHS(stmt1)) { + if (Fortran::semantics::checkForSymbolMatch(stmt2)) { + // Atomic capture construct is of the form [capture-stmt, update-stmt] + const Fortran::semantics::SomeExpr &fromExpr = + *Fortran::semantics::GetExpr(stmt1Expr); + mlir::Type elementType = converter.genType(fromExpr); + genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg, + elementType, loc); + genAtomicUpdateStatement(converter, stmt2LHSArg, stmt2VarType, stmt2Var, + stmt2Expr, loc, atomicCaptureOp); + } else { + // Atomic capture construct is of the form [capture-stmt, write-stmt] + firOpBuilder.setInsertionPoint(atomicCaptureOp); + mlir::Value stmt2RHSArg = + fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx)); + firOpBuilder.setInsertionPointToStart(&block); + const Fortran::semantics::SomeExpr &fromExpr = + *Fortran::semantics::GetExpr(stmt1Expr); + mlir::Type elementType = converter.genType(fromExpr); + genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg, + elementType, loc); + genAtomicWriteStatement(converter, stmt2LHSArg, stmt2RHSArg, loc); + } + } else { + // Atomic capture construct is of the form [update-stmt, capture-stmt] + const Fortran::semantics::SomeExpr &fromExpr = + *Fortran::semantics::GetExpr(stmt2Expr); + mlir::Type elementType = converter.genType(fromExpr); + genAtomicUpdateStatement(converter, stmt1LHSArg, stmt1VarType, stmt1Var, + stmt1Expr, loc, atomicCaptureOp); + genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg, elementType, + loc); + } + firOpBuilder.setInsertionPointToEnd(&block); + firOpBuilder.create(loc); + firOpBuilder.setInsertionPointToStart(&block); +} + template static void genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, @@ -4352,24 +4656,16 @@ genACC(Fortran::lower::AbstractConverter &converter, Fortran::common::visit( Fortran::common::visitors{ [&](const Fortran::parser::AccAtomicRead &atomicRead) { - Fortran::lower::genOmpAccAtomicRead(converter, atomicRead, - loc); + genAtomicRead(converter, atomicRead, loc); }, [&](const Fortran::parser::AccAtomicWrite &atomicWrite) { - Fortran::lower::genOmpAccAtomicWrite< - Fortran::parser::AccAtomicWrite, void>(converter, atomicWrite, - loc); + genAtomicWrite(converter, atomicWrite, loc); }, [&](const Fortran::parser::AccAtomicUpdate &atomicUpdate) { - Fortran::lower::genOmpAccAtomicUpdate< - Fortran::parser::AccAtomicUpdate, void>(converter, atomicUpdate, - loc); + genAtomicUpdate(converter, atomicUpdate, loc); }, [&](const Fortran::parser::AccAtomicCapture &atomicCapture) { - Fortran::lower::genOmpAccAtomicCapture< - Fortran::parser::AccAtomicCapture, void>(converter, - atomicCapture, loc); + genAtomicCapture(converter, atomicCapture, loc); }, }, atomicConstruct.u); diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 312557d5da07e..fdd85e94829f3 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -2585,6 +2585,460 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } +//===----------------------------------------------------------------------===// +// Code generation for atomic operations +//===----------------------------------------------------------------------===// + +/// Populates \p hint and \p memoryOrder with appropriate clause information +/// if present on atomic construct. +static void genOmpAtomicHintAndMemoryOrderClauses( + lower::AbstractConverter &converter, + const parser::OmpAtomicClauseList &clauseList, mlir::IntegerAttr &hint, + mlir::omp::ClauseMemoryOrderKindAttr &memoryOrder) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + for (const parser::OmpAtomicClause &clause : clauseList.v) { + common::visit( + common::visitors{ + [&](const parser::OmpMemoryOrderClause &s) { + auto kind = common::visit( + common::visitors{ + [&](const parser::OmpClause::AcqRel &) { + return mlir::omp::ClauseMemoryOrderKind::Acq_rel; + }, + [&](const parser::OmpClause::Acquire &) { + return mlir::omp::ClauseMemoryOrderKind::Acquire; + }, + [&](const parser::OmpClause::Relaxed &) { + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + }, + [&](const parser::OmpClause::Release &) { + return mlir::omp::ClauseMemoryOrderKind::Release; + }, + [&](const parser::OmpClause::SeqCst &) { + return mlir::omp::ClauseMemoryOrderKind::Seq_cst; + }, + [&](auto &&) -> mlir::omp::ClauseMemoryOrderKind { + llvm_unreachable("Unexpected clause"); + }, + }, + s.v.u); + memoryOrder = mlir::omp::ClauseMemoryOrderKindAttr::get( + firOpBuilder.getContext(), kind); + }, + [&](const parser::OmpHintClause &s) { + const auto *expr = semantics::GetExpr(s.v); + uint64_t hintExprValue = *evaluate::ToInt64(*expr); + hint = firOpBuilder.getI64IntegerAttr(hintExprValue); + }, + [&](const parser::OmpFailClause &) {}, + }, + clause.u); + } +} + +static void processOmpAtomicTODO(mlir::Type elementType, mlir::Location loc) { + if (!elementType) + return; + assert(fir::isa_trivial(fir::unwrapRefType(elementType)) && + "is supported type for omp atomic"); +} + +/// Used to generate atomic.read operation which is created in existing +/// location set by builder. +static void genAtomicCaptureStatement( + lower::AbstractConverter &converter, mlir::Value fromAddress, + mlir::Value toAddress, + const parser::OmpAtomicClauseList *leftHandClauseList, + const parser::OmpAtomicClauseList *rightHandClauseList, + mlir::Type elementType, mlir::Location loc) { + // Generate `atomic.read` operation for atomic assigment statements + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + processOmpAtomicTODO(elementType, loc); + + // If no hint clause is specified, the effect is as if + // hint(omp_sync_hint_none) had been specified. + mlir::IntegerAttr hint = nullptr; + + mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; + if (leftHandClauseList) + genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, + memoryOrder); + if (rightHandClauseList) + genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, + memoryOrder); + firOpBuilder.create(loc, fromAddress, toAddress, + mlir::TypeAttr::get(elementType), + hint, memoryOrder); +} + +/// Used to generate atomic.write operation which is created in existing +/// location set by builder. +static void genAtomicWriteStatement( + lower::AbstractConverter &converter, mlir::Value lhsAddr, + mlir::Value rhsExpr, const parser::OmpAtomicClauseList *leftHandClauseList, + const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc, + mlir::Value *evaluatedExprValue = nullptr) { + // Generate `atomic.write` operation for atomic assignment statements + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + mlir::Type varType = fir::unwrapRefType(lhsAddr.getType()); + // Create a conversion outside the capture block. + auto insertionPoint = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPointAfter(rhsExpr.getDefiningOp()); + rhsExpr = firOpBuilder.createConvert(loc, varType, rhsExpr); + firOpBuilder.restoreInsertionPoint(insertionPoint); + + processOmpAtomicTODO(varType, loc); + + // If no hint clause is specified, the effect is as if + // hint(omp_sync_hint_none) had been specified. + mlir::IntegerAttr hint = nullptr; + mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; + if (leftHandClauseList) + genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, + memoryOrder); + if (rightHandClauseList) + genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, + memoryOrder); + firOpBuilder.create(loc, lhsAddr, rhsExpr, hint, + memoryOrder); +} + +/// Used to generate atomic.update operation which is created in existing +/// location set by builder. +static void genAtomicUpdateStatement( + lower::AbstractConverter &converter, mlir::Value lhsAddr, + mlir::Type varType, const parser::Variable &assignmentStmtVariable, + const parser::Expr &assignmentStmtExpr, + const parser::OmpAtomicClauseList *leftHandClauseList, + const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc, + mlir::Operation *atomicCaptureOp = nullptr) { + // Generate `atomic.update` operation for atomic assignment statements + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + + // Create the omp.atomic.update or acc.atomic.update operation + // + // func.func @_QPsb() { + // %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"} + // %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"} + // %2 = fir.load %1 : !fir.ref + // omp.atomic.update %0 : !fir.ref { + // ^bb0(%arg0: i32): + // %3 = arith.addi %arg0, %2 : i32 + // omp.yield(%3 : i32) + // } + // return + // } + + auto getArgExpression = + [](std::list::const_iterator it) { + const auto &arg{std::get((*it).t)}; + const auto *parserExpr{ + std::get_if>(&arg.u)}; + return parserExpr; + }; + + // Lower any non atomic sub-expression before the atomic operation, and + // map its lowered value to the semantic representation. + lower::ExprToValueMap exprValueOverrides; + // Max and min intrinsics can have a list of Args. Hence we need a list + // of nonAtomicSubExprs to hoist. Currently, only the load is hoisted. + llvm::SmallVector nonAtomicSubExprs; + common::visit( + common::visitors{ + [&](const common::Indirection &funcRef) + -> void { + const auto &args{std::get>( + funcRef.value().v.t)}; + std::list::const_iterator beginIt = + args.begin(); + std::list::const_iterator endIt = args.end(); + const auto *exprFirst{getArgExpression(beginIt)}; + if (exprFirst && exprFirst->value().source == + assignmentStmtVariable.GetSource()) { + // Add everything except the first + beginIt++; + } else { + // Add everything except the last + endIt--; + } + std::list::const_iterator it; + for (it = beginIt; it != endIt; it++) { + const common::Indirection *expr = + getArgExpression(it); + if (expr) + nonAtomicSubExprs.push_back(semantics::GetExpr(*expr)); + } + }, + [&](const auto &op) -> void { + using T = std::decay_t; + if constexpr (std::is_base_of::value) { + const auto &exprLeft{std::get<0>(op.t)}; + const auto &exprRight{std::get<1>(op.t)}; + if (exprLeft.value().source == assignmentStmtVariable.GetSource()) + nonAtomicSubExprs.push_back(semantics::GetExpr(exprRight)); + else + nonAtomicSubExprs.push_back(semantics::GetExpr(exprLeft)); + } + }, + }, + assignmentStmtExpr.u); + lower::StatementContext nonAtomicStmtCtx; + if (!nonAtomicSubExprs.empty()) { + // Generate non atomic part before all the atomic operations. + auto insertionPoint = firOpBuilder.saveInsertionPoint(); + if (atomicCaptureOp) + firOpBuilder.setInsertionPoint(atomicCaptureOp); + mlir::Value nonAtomicVal; + for (auto *nonAtomicSubExpr : nonAtomicSubExprs) { + nonAtomicVal = fir::getBase(converter.genExprValue( + currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx)); + exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal); + } + if (atomicCaptureOp) + firOpBuilder.restoreInsertionPoint(insertionPoint); + } + + mlir::Operation *atomicUpdateOp = nullptr; + // If no hint clause is specified, the effect is as if + // hint(omp_sync_hint_none) had been specified. + mlir::IntegerAttr hint = nullptr; + mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; + if (leftHandClauseList) + genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, + memoryOrder); + if (rightHandClauseList) + genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, + memoryOrder); + atomicUpdateOp = firOpBuilder.create( + currentLocation, lhsAddr, hint, memoryOrder); + + processOmpAtomicTODO(varType, loc); + + llvm::SmallVector varTys = {varType}; + llvm::SmallVector locs = {currentLocation}; + firOpBuilder.createBlock(&atomicUpdateOp->getRegion(0), {}, varTys, locs); + mlir::Value val = + fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0)); + + exprValueOverrides.try_emplace(semantics::GetExpr(assignmentStmtVariable), + val); + { + // statement context inside the atomic block. + converter.overrideExprValues(&exprValueOverrides); + lower::StatementContext atomicStmtCtx; + mlir::Value rhsExpr = fir::getBase(converter.genExprValue( + *semantics::GetExpr(assignmentStmtExpr), atomicStmtCtx)); + mlir::Value convertResult = + firOpBuilder.createConvert(currentLocation, varType, rhsExpr); + firOpBuilder.create(currentLocation, convertResult); + converter.resetExprOverrides(); + } + firOpBuilder.setInsertionPointAfter(atomicUpdateOp); +} + +/// Processes an atomic construct with write clause. +static void genAtomicWrite(lower::AbstractConverter &converter, + const parser::OmpAtomicWrite &atomicWrite, + mlir::Location loc) { + const parser::OmpAtomicClauseList *rightHandClauseList = nullptr; + const parser::OmpAtomicClauseList *leftHandClauseList = nullptr; + // Get the address of atomic read operands. + rightHandClauseList = &std::get<2>(atomicWrite.t); + leftHandClauseList = &std::get<0>(atomicWrite.t); + + const parser::AssignmentStmt &stmt = + std::get>(atomicWrite.t) + .statement; + const evaluate::Assignment &assign = *stmt.typedAssignment->v; + lower::StatementContext stmtCtx; + // Get the value and address of atomic write operands. + mlir::Value rhsExpr = + fir::getBase(converter.genExprValue(assign.rhs, stmtCtx)); + mlir::Value lhsAddr = + fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx)); + genAtomicWriteStatement(converter, lhsAddr, rhsExpr, leftHandClauseList, + rightHandClauseList, loc); +} + +/// Processes an atomic construct with read clause. +static void genAtomicRead(lower::AbstractConverter &converter, + const parser::OmpAtomicRead &atomicRead, + mlir::Location loc) { + const parser::OmpAtomicClauseList *rightHandClauseList = nullptr; + const parser::OmpAtomicClauseList *leftHandClauseList = nullptr; + // Get the address of atomic read operands. + rightHandClauseList = &std::get<2>(atomicRead.t); + leftHandClauseList = &std::get<0>(atomicRead.t); + + const auto &assignmentStmtExpr = std::get( + std::get>(atomicRead.t) + .statement.t); + const auto &assignmentStmtVariable = std::get( + std::get>(atomicRead.t) + .statement.t); + + lower::StatementContext stmtCtx; + const semantics::SomeExpr &fromExpr = *semantics::GetExpr(assignmentStmtExpr); + mlir::Type elementType = converter.genType(fromExpr); + mlir::Value fromAddress = + fir::getBase(converter.genExprAddr(fromExpr, stmtCtx)); + mlir::Value toAddress = fir::getBase(converter.genExprAddr( + *semantics::GetExpr(assignmentStmtVariable), stmtCtx)); + genAtomicCaptureStatement(converter, fromAddress, toAddress, + leftHandClauseList, rightHandClauseList, + elementType, loc); +} + +/// Processes an atomic construct with update clause. +static void genAtomicUpdate(lower::AbstractConverter &converter, + const parser::OmpAtomicUpdate &atomicUpdate, + mlir::Location loc) { + const parser::OmpAtomicClauseList *rightHandClauseList = nullptr; + const parser::OmpAtomicClauseList *leftHandClauseList = nullptr; + // Get the address of atomic read operands. + rightHandClauseList = &std::get<2>(atomicUpdate.t); + leftHandClauseList = &std::get<0>(atomicUpdate.t); + + const auto &assignmentStmtExpr = std::get( + std::get>(atomicUpdate.t) + .statement.t); + const auto &assignmentStmtVariable = std::get( + std::get>(atomicUpdate.t) + .statement.t); + + lower::StatementContext stmtCtx; + mlir::Value lhsAddr = fir::getBase(converter.genExprAddr( + *semantics::GetExpr(assignmentStmtVariable), stmtCtx)); + mlir::Type varType = fir::unwrapRefType(lhsAddr.getType()); + genAtomicUpdateStatement(converter, lhsAddr, varType, assignmentStmtVariable, + assignmentStmtExpr, leftHandClauseList, + rightHandClauseList, loc); +} + +/// Processes an atomic construct with no clause - which implies update clause. +static void genOmpAtomic(lower::AbstractConverter &converter, + const parser::OmpAtomic &atomicConstruct, + mlir::Location loc) { + const parser::OmpAtomicClauseList &atomicClauseList = + std::get(atomicConstruct.t); + const auto &assignmentStmtExpr = std::get( + std::get>(atomicConstruct.t) + .statement.t); + const auto &assignmentStmtVariable = std::get( + std::get>(atomicConstruct.t) + .statement.t); + lower::StatementContext stmtCtx; + mlir::Value lhsAddr = fir::getBase(converter.genExprAddr( + *semantics::GetExpr(assignmentStmtVariable), stmtCtx)); + mlir::Type varType = fir::unwrapRefType(lhsAddr.getType()); + // If atomic-clause is not present on the construct, the behaviour is as if + // the update clause is specified (for both OpenMP and OpenACC). + genAtomicUpdateStatement(converter, lhsAddr, varType, assignmentStmtVariable, + assignmentStmtExpr, &atomicClauseList, nullptr, loc); +} + +/// Processes an atomic construct with capture clause. +static void genAtomicCapture(lower::AbstractConverter &converter, + const parser::OmpAtomicCapture &atomicCapture, + mlir::Location loc) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + const parser::AssignmentStmt &stmt1 = + std::get(atomicCapture.t).v.statement; + const evaluate::Assignment &assign1 = *stmt1.typedAssignment->v; + const auto &stmt1Var{std::get(stmt1.t)}; + const auto &stmt1Expr{std::get(stmt1.t)}; + const parser::AssignmentStmt &stmt2 = + std::get(atomicCapture.t).v.statement; + const evaluate::Assignment &assign2 = *stmt2.typedAssignment->v; + const auto &stmt2Var{std::get(stmt2.t)}; + const auto &stmt2Expr{std::get(stmt2.t)}; + + // Pre-evaluate expressions to be used in the various operations inside + // `atomic.capture` since it is not desirable to have anything other than + // a `atomic.read`, `atomic.write`, or `atomic.update` operation + // inside `atomic.capture` + lower::StatementContext stmtCtx; + // LHS evaluations are common to all combinations of `atomic.capture` + mlir::Value stmt1LHSArg = + fir::getBase(converter.genExprAddr(assign1.lhs, stmtCtx)); + mlir::Value stmt2LHSArg = + fir::getBase(converter.genExprAddr(assign2.lhs, stmtCtx)); + + // Type information used in generation of `atomic.update` operation + mlir::Type stmt1VarType = + fir::getBase(converter.genExprValue(assign1.lhs, stmtCtx)).getType(); + mlir::Type stmt2VarType = + fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType(); + + mlir::Operation *atomicCaptureOp = nullptr; + mlir::IntegerAttr hint = nullptr; + mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr; + const parser::OmpAtomicClauseList &rightHandClauseList = + std::get<2>(atomicCapture.t); + const parser::OmpAtomicClauseList &leftHandClauseList = + std::get<0>(atomicCapture.t); + genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, + memoryOrder); + genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, + memoryOrder); + atomicCaptureOp = + firOpBuilder.create(loc, hint, memoryOrder); + + firOpBuilder.createBlock(&(atomicCaptureOp->getRegion(0))); + mlir::Block &block = atomicCaptureOp->getRegion(0).back(); + firOpBuilder.setInsertionPointToStart(&block); + if (semantics::checkForSingleVariableOnRHS(stmt1)) { + if (semantics::checkForSymbolMatch(stmt2)) { + // Atomic capture construct is of the form [capture-stmt, update-stmt] + const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt1Expr); + mlir::Type elementType = converter.genType(fromExpr); + genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg, + /*leftHandClauseList=*/nullptr, + /*rightHandClauseList=*/nullptr, elementType, + loc); + genAtomicUpdateStatement( + converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr, + /*leftHandClauseList=*/nullptr, + /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp); + } else { + // Atomic capture construct is of the form [capture-stmt, write-stmt] + firOpBuilder.setInsertionPoint(atomicCaptureOp); + mlir::Value stmt2RHSArg = + fir::getBase(converter.genExprValue(assign2.rhs, stmtCtx)); + firOpBuilder.setInsertionPointToStart(&block); + const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt1Expr); + mlir::Type elementType = converter.genType(fromExpr); + genAtomicCaptureStatement(converter, stmt2LHSArg, stmt1LHSArg, + /*leftHandClauseList=*/nullptr, + /*rightHandClauseList=*/nullptr, elementType, + loc); + genAtomicWriteStatement(converter, stmt2LHSArg, stmt2RHSArg, + /*leftHandClauseList=*/nullptr, + /*rightHandClauseList=*/nullptr, loc); + } + } else { + // Atomic capture construct is of the form [update-stmt, capture-stmt] + const semantics::SomeExpr &fromExpr = *semantics::GetExpr(stmt2Expr); + mlir::Type elementType = converter.genType(fromExpr); + genAtomicUpdateStatement( + converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr, + /*leftHandClauseList=*/nullptr, + /*rightHandClauseList=*/nullptr, loc, atomicCaptureOp); + genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg, + /*leftHandClauseList=*/nullptr, + /*rightHandClauseList=*/nullptr, elementType, + loc); + } + firOpBuilder.setInsertionPointToEnd(&block); + firOpBuilder.create(loc); + firOpBuilder.setInsertionPointToStart(&block); +} + //===----------------------------------------------------------------------===// // Code generation functions for the standalone version of constructs that can // also be a leaf of a composite construct @@ -3476,32 +3930,23 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, common::visitors{ [&](const parser::OmpAtomicRead &atomicRead) { mlir::Location loc = converter.genLocation(atomicRead.source); - lower::genOmpAccAtomicRead( - converter, atomicRead, loc); + genAtomicRead(converter, atomicRead, loc); }, [&](const parser::OmpAtomicWrite &atomicWrite) { mlir::Location loc = converter.genLocation(atomicWrite.source); - lower::genOmpAccAtomicWrite( - converter, atomicWrite, loc); + genAtomicWrite(converter, atomicWrite, loc); }, [&](const parser::OmpAtomic &atomicConstruct) { mlir::Location loc = converter.genLocation(atomicConstruct.source); - lower::genOmpAtomic( - converter, atomicConstruct, loc); + genOmpAtomic(converter, atomicConstruct, loc); }, [&](const parser::OmpAtomicUpdate &atomicUpdate) { mlir::Location loc = converter.genLocation(atomicUpdate.source); - lower::genOmpAccAtomicUpdate( - converter, atomicUpdate, loc); + genAtomicUpdate(converter, atomicUpdate, loc); }, [&](const parser::OmpAtomicCapture &atomicCapture) { mlir::Location loc = converter.genLocation(atomicCapture.source); - lower::genOmpAccAtomicCapture( - converter, atomicCapture, loc); + genAtomicCapture(converter, atomicCapture, loc); }, [&](const parser::OmpAtomicCompare &atomicCompare) { mlir::Location loc = converter.genLocation(atomicCompare.source); diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index eff6d57995d2b..cdfd3e3223fa8 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -602,22 +602,20 @@ def OMP_Assume : Directive<"assume"> { ]; } def OMP_Atomic : Directive<"atomic"> { - let allowedClauses = [ - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - VersionedClause, - ]; let allowedOnceClauses = [ VersionedClause, VersionedClause, + VersionedClause, + VersionedClause, VersionedClause, VersionedClause, + VersionedClause, VersionedClause, VersionedClause, VersionedClause, + VersionedClause, VersionedClause, + VersionedClause, ]; let association = AS_Block; let category = CA_Executable; @@ -668,7 +666,7 @@ def OMP_CancellationPoint : Directive<"cancellation point"> { let category = CA_Executable; } def OMP_Critical : Directive<"critical"> { - let allowedClauses = [ + let allowedOnceClauses = [ VersionedClause, ]; let association = AS_Block;