2929#include " flang/Lower/PFTBuilder.h"
3030#include " flang/Lower/StatementContext.h"
3131#include " flang/Lower/Support/Utils.h"
32+ #include " flang/Optimizer/Builder/Complex.h"
3233#include " flang/Optimizer/Builder/DirectivesCommon.h"
3334#include " flang/Optimizer/Builder/HLFIRTools.h"
3435#include " flang/Optimizer/Dialect/FIRType.h"
@@ -103,6 +104,61 @@ static void processOmpAtomicTODO(mlir::Type elementType,
103104 }
104105}
105106
107+ // / Emits an implicit cast for atomic statements
108+ static void emitImplicitCast (Fortran::lower::AbstractConverter &converter,
109+ mlir::Location loc, mlir::Value &fromAddress,
110+ mlir::Value &toAddress, mlir::Type &elementType) {
111+ if (fromAddress.getType () == toAddress.getType ())
112+ return ;
113+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
114+ mlir::Value alloca = builder.create <fir::AllocaOp>(
115+ loc, fir::unwrapRefType (toAddress.getType ()));
116+ mlir::Value loadedVal = builder.create <fir::LoadOp>(loc, fromAddress);
117+ mlir::Type toType = fir::unwrapRefType (toAddress.getType ());
118+ mlir::Type fromType = fir::unwrapRefType (fromAddress.getType ());
119+ if (!fir::isa_complex (toType) && !fir::isa_complex (fromType)) {
120+ loadedVal = builder.create <fir::ConvertOp>(
121+ loc, fir::unwrapRefType (toAddress.getType ()), loadedVal);
122+ builder.create <fir::StoreOp>(loc, loadedVal, alloca);
123+ } else if (!fir::isa_complex (toType) && fir::isa_complex (fromType)) {
124+ loadedVal = builder.create <fir::ExtractValueOp>(
125+ loc, mlir::cast<mlir::ComplexType>(fromType).getElementType (),
126+ loadedVal,
127+ builder.getArrayAttr (
128+ builder.getIntegerAttr (builder.getIndexType (), 0 )));
129+ loadedVal = builder.create <fir::ConvertOp>(loc, toType, loadedVal);
130+ builder.create <fir::StoreOp>(loc, loadedVal, alloca);
131+ } else if (fir::isa_complex (toType) && fir::isa_complex (fromType)) {
132+ mlir::Value firstComp = builder.create <fir::ExtractValueOp>(
133+ loc, mlir::cast<mlir::ComplexType>(fromType).getElementType (),
134+ loadedVal,
135+ builder.getArrayAttr (
136+ builder.getIntegerAttr (builder.getIndexType (), 0 )));
137+ mlir::Value secondComp = builder.create <fir::ExtractValueOp>(
138+ loc, mlir::cast<mlir::ComplexType>(fromType).getElementType (),
139+ loadedVal,
140+ builder.getArrayAttr (
141+ builder.getIntegerAttr (builder.getIndexType (), 1 )));
142+ firstComp = builder.create <fir::ConvertOp>(
143+ loc, mlir::cast<mlir::ComplexType>(toType).getElementType (), firstComp);
144+ secondComp = builder.create <fir::ConvertOp>(
145+ loc, mlir::cast<mlir::ComplexType>(toType).getElementType (),
146+ secondComp);
147+ auto undef = builder.create <fir::UndefOp>(loc, toType);
148+ mlir::Value pair1 = builder.create <fir::InsertValueOp>(
149+ loc, toType, undef, firstComp,
150+ builder.getArrayAttr (
151+ builder.getIntegerAttr (builder.getIndexType (), 0 )));
152+ mlir::Value pair = builder.create <fir::InsertValueOp>(
153+ loc, toType, pair1, secondComp,
154+ builder.getArrayAttr (
155+ builder.getIntegerAttr (builder.getIndexType (), 1 )));
156+ builder.create <fir::StoreOp>(loc, pair, alloca);
157+ }
158+ fromAddress = alloca;
159+ elementType = fir::unwrapRefType (toAddress.getType ());
160+ }
161+
106162// / Used to generate atomic.read operation which is created in existing
107163// / location set by builder.
108164template <typename AtomicListT>
@@ -386,6 +442,7 @@ void genOmpAccAtomicRead(Fortran::lower::AbstractConverter &converter,
386442 fir::getBase (converter.genExprAddr (fromExpr, stmtCtx));
387443 mlir::Value toAddress = fir::getBase (converter.genExprAddr (
388444 *Fortran::semantics::GetExpr (assignmentStmtVariable), stmtCtx));
445+ emitImplicitCast (converter, loc, fromAddress, toAddress, elementType);
389446 genOmpAccAtomicCaptureStatement (converter, fromAddress, toAddress,
390447 leftHandClauseList, rightHandClauseList,
391448 elementType, loc);
@@ -481,6 +538,30 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
481538 mlir::Type stmt2VarType =
482539 fir::getBase (converter.genExprValue (assign2.lhs , stmtCtx)).getType ();
483540
541+ // Checks helpful in constructing the `atomic.capture` region
542+ bool hasSingleVariable =
543+ Fortran::semantics::checkForSingleVariableOnRHS (stmt1);
544+ bool hasSymMatch = Fortran::semantics::checkForSymbolMatch (stmt2);
545+
546+ // Implicit casts
547+ mlir::Type captureStmtElemTy;
548+ if (hasSingleVariable) {
549+ if (hasSymMatch) {
550+ // Atomic capture construct is of the form [capture-stmt, update-stmt]
551+ // FIXME: Emit an implicit cast if there is a type mismatch
552+ } else {
553+ // Atomic capture construct is of the form [capture-stmt, write-stmt]
554+ const Fortran::semantics::SomeExpr &fromExpr =
555+ *Fortran::semantics::GetExpr (stmt1Expr);
556+ captureStmtElemTy = converter.genType (fromExpr);
557+ emitImplicitCast (converter, loc, stmt2LHSArg, stmt1LHSArg,
558+ captureStmtElemTy);
559+ }
560+ } else {
561+ // Atomic capture construct is of the form [update-stmt, capture-stmt]
562+ // FIXME: Emit an implicit cast if there is a type mismatch
563+ }
564+
484565 mlir::Operation *atomicCaptureOp = nullptr ;
485566 if constexpr (std::is_same<AtomicListT,
486567 Fortran::parser::OmpAtomicClauseList>()) {
@@ -501,8 +582,8 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
501582 firOpBuilder.createBlock (&(atomicCaptureOp->getRegion (0 )));
502583 mlir::Block &block = atomicCaptureOp->getRegion (0 ).back ();
503584 firOpBuilder.setInsertionPointToStart (&block);
504- if (Fortran::semantics::checkForSingleVariableOnRHS (stmt1) ) {
505- if (Fortran::semantics::checkForSymbolMatch (stmt2) ) {
585+ if (hasSingleVariable ) {
586+ if (hasSymMatch ) {
506587 // Atomic capture construct is of the form [capture-stmt, update-stmt]
507588 const Fortran::semantics::SomeExpr &fromExpr =
508589 *Fortran::semantics::GetExpr (stmt1Expr);
@@ -521,13 +602,10 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
521602 mlir::Value stmt2RHSArg =
522603 fir::getBase (converter.genExprValue (assign2.rhs , stmtCtx));
523604 firOpBuilder.setInsertionPointToStart (&block);
524- const Fortran::semantics::SomeExpr &fromExpr =
525- *Fortran::semantics::GetExpr (stmt1Expr);
526- mlir::Type elementType = converter.genType (fromExpr);
527605 genOmpAccAtomicCaptureStatement<AtomicListT>(
528606 converter, stmt2LHSArg, stmt1LHSArg,
529607 /* leftHandClauseList=*/ nullptr ,
530- /* rightHandClauseList=*/ nullptr , elementType , loc);
608+ /* rightHandClauseList=*/ nullptr , captureStmtElemTy , loc);
531609 genOmpAccAtomicWriteStatement<AtomicListT>(
532610 converter, stmt2LHSArg, stmt2RHSArg,
533611 /* leftHandClauseList=*/ nullptr ,
0 commit comments