1313#include " flang/Optimizer/Builder/FIRBuilder.h"
1414#include " flang/Optimizer/Builder/Factory.h"
1515#include " flang/Optimizer/Dialect/FIRDialect.h"
16+ #include " flang/Optimizer/Dialect/FIROpsSupport.h"
1617#include " flang/Optimizer/Support/FIRContext.h"
1718#include " flang/Optimizer/Transforms/Passes.h"
1819#include " mlir/Dialect/SCF/SCF.h"
@@ -224,6 +225,11 @@ class ReachCollector {
224225 for (auto *user : op->getResult (0 ).getUsers ())
225226 followUsers (user);
226227
228+ if (mlir::isa<fir::CallOp>(op)) {
229+ LLVM_DEBUG (llvm::dbgs () << " add " << *op << " to reachable set\n " );
230+ reach.push_back (op);
231+ }
232+
227233 for (auto u : op->getOperands ())
228234 collectArrayMentionFrom (u);
229235 }
@@ -539,6 +545,22 @@ static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
539545 return conflictOnLoad (reach, st) || conflictOnMerge (mentions);
540546}
541547
548+ // Assume that any call to a function that uses host-associations will be
549+ // modifying the output array.
550+ static bool
551+ conservativeCallConflict (llvm::ArrayRef<mlir::Operation *> reaches) {
552+ return llvm::any_of (reaches, [](mlir::Operation *op) {
553+ if (auto call = mlir::dyn_cast<fir::CallOp>(op))
554+ if (auto callee =
555+ call.getCallableForCallee ().dyn_cast <mlir::SymbolRefAttr>()) {
556+ auto module = op->getParentOfType <mlir::ModuleOp>();
557+ return fir::hasHostAssociationArgument (
558+ module .lookupSymbol <mlir::FuncOp>(callee));
559+ }
560+ return false ;
561+ });
562+ }
563+
542564// / Constructor of the array copy analysis.
543565// / This performs the analysis and saves the intermediate results.
544566void ArrayCopyAnalysis::construct (mlir::MutableArrayRef<mlir::Region> regions) {
@@ -550,12 +572,13 @@ void ArrayCopyAnalysis::construct(mlir::MutableArrayRef<mlir::Region> regions) {
550572 if (auto st = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
551573 llvm::SmallVector<Operation *> values;
552574 ReachCollector::reachingValues (values, st.sequence ());
575+ auto callConflict = conservativeCallConflict (values);
553576 llvm::SmallVector<Operation *> mentions;
554577 arrayMentions (mentions,
555578 mlir::cast<ArrayLoadOp>(st.original ().getDefiningOp ()));
556579 auto conflict = conflictDetected (values, mentions, st);
557580 auto refConflict = conflictOnReference (mentions);
558- if (conflict || refConflict) {
581+ if (callConflict || conflict || refConflict) {
559582 LLVM_DEBUG (llvm::dbgs ()
560583 << " CONFLICT: copies required for " << st << ' \n '
561584 << " adding conflicts on: " << op << " and "
0 commit comments