13
13
#include " flang/Optimizer/Builder/FIRBuilder.h"
14
14
#include " flang/Optimizer/Builder/Factory.h"
15
15
#include " flang/Optimizer/Dialect/FIRDialect.h"
16
+ #include " flang/Optimizer/Dialect/FIROpsSupport.h"
16
17
#include " flang/Optimizer/Support/FIRContext.h"
17
18
#include " flang/Optimizer/Transforms/Passes.h"
18
19
#include " mlir/Dialect/SCF/SCF.h"
@@ -224,6 +225,11 @@ class ReachCollector {
224
225
for (auto *user : op->getResult (0 ).getUsers ())
225
226
followUsers (user);
226
227
228
+ if (mlir::isa<fir::CallOp>(op)) {
229
+ LLVM_DEBUG (llvm::dbgs () << " add " << *op << " to reachable set\n " );
230
+ reach.push_back (op);
231
+ }
232
+
227
233
for (auto u : op->getOperands ())
228
234
collectArrayMentionFrom (u);
229
235
}
@@ -539,6 +545,22 @@ static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
539
545
return conflictOnLoad (reach, st) || conflictOnMerge (mentions);
540
546
}
541
547
548
+ // Assume that any call to a function that uses host-associations will be
549
+ // modifying the output array.
550
+ static bool
551
+ conservativeCallConflict (llvm::ArrayRef<mlir::Operation *> reaches) {
552
+ return llvm::any_of (reaches, [](mlir::Operation *op) {
553
+ if (auto call = mlir::dyn_cast<fir::CallOp>(op))
554
+ if (auto callee =
555
+ call.getCallableForCallee ().dyn_cast <mlir::SymbolRefAttr>()) {
556
+ auto module = op->getParentOfType <mlir::ModuleOp>();
557
+ return fir::hasHostAssociationArgument (
558
+ module .lookupSymbol <mlir::FuncOp>(callee));
559
+ }
560
+ return false ;
561
+ });
562
+ }
563
+
542
564
// / Constructor of the array copy analysis.
543
565
// / This performs the analysis and saves the intermediate results.
544
566
void ArrayCopyAnalysis::construct (mlir::MutableArrayRef<mlir::Region> regions) {
@@ -550,12 +572,13 @@ void ArrayCopyAnalysis::construct(mlir::MutableArrayRef<mlir::Region> regions) {
550
572
if (auto st = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
551
573
llvm::SmallVector<Operation *> values;
552
574
ReachCollector::reachingValues (values, st.sequence ());
575
+ auto callConflict = conservativeCallConflict (values);
553
576
llvm::SmallVector<Operation *> mentions;
554
577
arrayMentions (mentions,
555
578
mlir::cast<ArrayLoadOp>(st.original ().getDefiningOp ()));
556
579
auto conflict = conflictDetected (values, mentions, st);
557
580
auto refConflict = conflictOnReference (mentions);
558
- if (conflict || refConflict) {
581
+ if (callConflict || conflict || refConflict) {
559
582
LLVM_DEBUG (llvm::dbgs ()
560
583
<< " CONFLICT: copies required for " << st << ' \n '
561
584
<< " adding conflicts on: " << op << " and "
0 commit comments