Skip to content

Commit ee5aa72

Browse files
authored
Merge pull request #1194 from schweitzpgi/ch-avc5
When there is a call to a function which host associations,
2 parents e9694d3 + 4b4b84f commit ee5aa72

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

flang/include/flang/Optimizer/Dialect/FIROpsSupport.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,6 @@ mlir::FuncOp createFuncOp(mlir::Location loc, mlir::ModuleOp module,
5656
llvm::StringRef name, mlir::FunctionType type,
5757
llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
5858

59-
/// Get or create a GlobalOp in a module.
60-
fir::GlobalOp createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
61-
llvm::StringRef name, mlir::Type type,
62-
llvm::ArrayRef<mlir::NamedAttribute> attrs = {});
63-
6459
/// Attribute to mark Fortran entities with the CONTIGUOUS attribute.
6560
static constexpr llvm::StringRef getContiguousAttrName() {
6661
return "fir.contiguous";
@@ -82,6 +77,10 @@ static constexpr llvm::StringRef getHostAssocAttrName() {
8277
return "fir.host_assoc";
8378
}
8479

80+
/// Does the function, \p func, have a host-associations tuple argument?
81+
/// Some internal procedures may have access to host procedure variables.
82+
bool hasHostAssociationArgument(mlir::FuncOp func);
83+
8584
/// Tell if \p value is:
8685
/// - a function argument that has attribute \p attributeName
8786
/// - or, the result of fir.alloca/fir.allocamem op that has attribute \p

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3221,15 +3221,13 @@ mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
32213221
return result;
32223222
}
32233223

3224-
fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
3225-
StringRef name, mlir::Type type,
3226-
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
3227-
if (auto g = module.lookupSymbol<fir::GlobalOp>(name))
3228-
return g;
3229-
mlir::OpBuilder modBuilder(module.getBodyRegion());
3230-
auto result = modBuilder.create<fir::GlobalOp>(loc, name, type, attrs);
3231-
result.setVisibility(mlir::SymbolTable::Visibility::Private);
3232-
return result;
3224+
bool fir::hasHostAssociationArgument(mlir::FuncOp func) {
3225+
if (auto allArgAttrs = func.getAllArgAttrs())
3226+
for (auto attr : allArgAttrs)
3227+
if (auto dict = attr.template dyn_cast_or_null<mlir::DictionaryAttr>())
3228+
if (dict.get(fir::getHostAssocAttrName()))
3229+
return true;
3230+
return false;
32333231
}
32343232

32353233
bool fir::valueHasFirAttribute(mlir::Value value,

flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "flang/Optimizer/Builder/FIRBuilder.h"
1414
#include "flang/Optimizer/Builder/Factory.h"
1515
#include "flang/Optimizer/Dialect/FIRDialect.h"
16+
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
1617
#include "flang/Optimizer/Support/FIRContext.h"
1718
#include "flang/Optimizer/Transforms/Passes.h"
1819
#include "mlir/Dialect/SCF/SCF.h"
@@ -224,6 +225,11 @@ class ReachCollector {
224225
for (auto *user : op->getResult(0).getUsers())
225226
followUsers(user);
226227

228+
if (mlir::isa<fir::CallOp>(op)) {
229+
LLVM_DEBUG(llvm::dbgs() << "add " << *op << " to reachable set\n");
230+
reach.push_back(op);
231+
}
232+
227233
for (auto u : op->getOperands())
228234
collectArrayMentionFrom(u);
229235
}
@@ -539,6 +545,22 @@ static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
539545
return conflictOnLoad(reach, st) || conflictOnMerge(mentions);
540546
}
541547

548+
// Assume that any call to a function that uses host-associations will be
549+
// modifying the output array.
550+
static bool
551+
conservativeCallConflict(llvm::ArrayRef<mlir::Operation *> reaches) {
552+
return llvm::any_of(reaches, [](mlir::Operation *op) {
553+
if (auto call = mlir::dyn_cast<fir::CallOp>(op))
554+
if (auto callee =
555+
call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>()) {
556+
auto module = op->getParentOfType<mlir::ModuleOp>();
557+
return fir::hasHostAssociationArgument(
558+
module.lookupSymbol<mlir::FuncOp>(callee));
559+
}
560+
return false;
561+
});
562+
}
563+
542564
/// Constructor of the array copy analysis.
543565
/// This performs the analysis and saves the intermediate results.
544566
void ArrayCopyAnalysis::construct(mlir::MutableArrayRef<mlir::Region> regions) {
@@ -550,12 +572,13 @@ void ArrayCopyAnalysis::construct(mlir::MutableArrayRef<mlir::Region> regions) {
550572
if (auto st = mlir::dyn_cast<ArrayMergeStoreOp>(op)) {
551573
llvm::SmallVector<Operation *> values;
552574
ReachCollector::reachingValues(values, st.sequence());
575+
auto callConflict = conservativeCallConflict(values);
553576
llvm::SmallVector<Operation *> mentions;
554577
arrayMentions(mentions,
555578
mlir::cast<ArrayLoadOp>(st.original().getDefiningOp()));
556579
auto conflict = conflictDetected(values, mentions, st);
557580
auto refConflict = conflictOnReference(mentions);
558-
if (conflict || refConflict) {
581+
if (callConflict || conflict || refConflict) {
559582
LLVM_DEBUG(llvm::dbgs()
560583
<< "CONFLICT: copies required for " << st << '\n'
561584
<< " adding conflicts on: " << op << " and "

0 commit comments

Comments
 (0)