Skip to content

Commit 32c9854

Browse files
[flang] Extend symbol update to ArrayAttrext in external-name-interop (#150061)
In the `external-name-interop` pass, when a symbol is changed, all the uses of the renamed symbols should also be updated. The update was only applied to the `SymbolRefAttr` type. With this change, the update will be applied to `ArrayAttr` containing elements of type `SymbolRefAttr`. --------- Co-authored-by: Delaram Talaashrafi <[email protected]>
1 parent eb817c7 commit 32c9854

File tree

2 files changed

+89
-11
lines changed

2 files changed

+89
-11
lines changed

flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,23 @@ mangleExternalName(const std::pair<fir::NameUniquer::NameKind,
4141
appendUnderscore);
4242
}
4343

44+
/// Process a symbol reference and return the updated symbol reference if
45+
/// needed.
46+
std::optional<mlir::SymbolRefAttr>
47+
processSymbolRef(mlir::SymbolRefAttr symRef, mlir::Operation *nestedOp,
48+
const llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr>
49+
&remappings) {
50+
if (auto remap = remappings.find(symRef.getLeafReference());
51+
remap != remappings.end()) {
52+
mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr(remap->second);
53+
if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
54+
symAttr = mlir::SymbolRefAttr::get(
55+
symRef.getRootReference(), {mlir::FlatSymbolRefAttr(remap->second)});
56+
return symAttr;
57+
}
58+
return std::nullopt;
59+
}
60+
4461
namespace {
4562

4663
class ExternalNameConversionPass
@@ -97,21 +114,40 @@ void ExternalNameConversionPass::runOnOperation() {
97114

98115
// Update all uses of the functions and globals that have been renamed.
99116
op.walk([&remappings](mlir::Operation *nestedOp) {
100-
llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
117+
llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>>
118+
symRefUpdates;
119+
llvm::SmallVector<std::pair<mlir::StringAttr, mlir::ArrayAttr>>
120+
arrayUpdates;
101121
for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary())
102122
if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue())) {
103-
if (auto remap = remappings.find(symRef.getLeafReference());
104-
remap != remappings.end()) {
105-
mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr(remap->second);
106-
if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
107-
symAttr = mlir::SymbolRefAttr::get(
108-
symRef.getRootReference(),
109-
{mlir::FlatSymbolRefAttr(remap->second)});
110-
updates.emplace_back(std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
111-
attr.getName(), symAttr});
123+
if (auto newSymRef = processSymbolRef(symRef, nestedOp, remappings))
124+
symRefUpdates.emplace_back(
125+
std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{attr.getName(),
126+
*newSymRef});
127+
} else if (auto arrayAttr =
128+
llvm::dyn_cast<mlir::ArrayAttr>(attr.getValue())) {
129+
llvm::SmallVector<mlir::Attribute> symbolRefs;
130+
for (auto element : arrayAttr) {
131+
if (!element) {
132+
symbolRefs.push_back(element);
133+
continue;
134+
}
135+
auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(element);
136+
std::optional<mlir::SymbolRefAttr> updatedSymRef;
137+
if (symRef)
138+
updatedSymRef = processSymbolRef(symRef, nestedOp, remappings);
139+
if (!symRef || !updatedSymRef)
140+
symbolRefs.push_back(element);
141+
else
142+
symbolRefs.push_back(*updatedSymRef);
112143
}
144+
arrayUpdates.push_back(std::make_pair(
145+
attr.getName(),
146+
mlir::ArrayAttr::get(nestedOp->getContext(), symbolRefs)));
113147
}
114-
for (auto update : updates)
148+
for (auto update : symRefUpdates)
149+
nestedOp->setAttr(update.first, update.second);
150+
for (auto update : arrayUpdates)
115151
nestedOp->setAttr(update.first, update.second);
116152
});
117153
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Test fir.do_concurrent.loop operation with array of symbol reference attributes
2+
// This test demonstrates operations that have ArrayAttr containing SymbolRefAttr elements
3+
4+
// RUN: fir-opt %s --external-name-interop | fir-opt | FileCheck %s
5+
6+
// Define reduction operations that will be referenced in the symbol array
7+
func.func @_QPadd_reduction_i32_init(%arg0: i32, %arg1: !fir.ref<i32>) {
8+
%0 = arith.constant 0 : i32
9+
fir.store %0 to %arg1 : !fir.ref<i32>
10+
return
11+
}
12+
13+
func.func @_QPadd_reduction_i32_combiner(%arg0: i32, %arg1: i32) -> i32 {
14+
%0 = arith.addi %arg0, %arg1 : i32
15+
return %0 : i32
16+
}
17+
18+
// Define a local privatizer that will be referenced in local_syms
19+
func.func @_QPlocal_var_privatizer(%arg0: !fir.ref<i32>) -> !fir.ref<i32> {
20+
return %arg0 : !fir.ref<i32>
21+
}
22+
23+
// Test function demonstrating both local_syms and reduce_syms arrays
24+
func.func @_QPtest_symbol_arrays(%i_lb: index, %i_ub: index, %i_st: index) {
25+
%local_var = fir.alloca i32
26+
%sum = fir.alloca i32
27+
28+
fir.do_concurrent {
29+
%i = fir.alloca i32
30+
fir.do_concurrent.loop (%i_iv) = (%i_lb) to (%i_ub) step (%i_st)
31+
local(@_QPlocal_var_privatizer %local_var -> %local_arg : !fir.ref<i32>)
32+
reduce(@_QPadd_reduction_i32_init #fir.reduce_attr<add> %sum -> %sum_arg : !fir.ref<i32>) {
33+
%0 = fir.convert %i_iv : (index) -> i32
34+
fir.store %0 to %i : !fir.ref<i32>
35+
}
36+
}
37+
return
38+
}
39+
40+
// CHECK: local(@local_var_privatizer_
41+
// CHECK: reduce(@add_reduction_i32_init_
42+

0 commit comments

Comments
 (0)