@@ -41,6 +41,23 @@ mangleExternalName(const std::pair<fir::NameUniquer::NameKind,
41
41
appendUnderscore);
42
42
}
43
43
44
+ // / Process a symbol reference and return the updated symbol reference if
45
+ // / needed.
46
+ std::optional<mlir::SymbolRefAttr>
47
+ processSymbolRef (mlir::SymbolRefAttr symRef, mlir::Operation *nestedOp,
48
+ const llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr>
49
+ &remappings) {
50
+ if (auto remap = remappings.find (symRef.getLeafReference ());
51
+ remap != remappings.end ()) {
52
+ mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr (remap->second );
53
+ if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
54
+ symAttr = mlir::SymbolRefAttr::get (
55
+ symRef.getRootReference (), {mlir::FlatSymbolRefAttr (remap->second )});
56
+ return symAttr;
57
+ }
58
+ return std::nullopt;
59
+ }
60
+
44
61
namespace {
45
62
46
63
class ExternalNameConversionPass
@@ -97,21 +114,40 @@ void ExternalNameConversionPass::runOnOperation() {
97
114
98
115
// Update all uses of the functions and globals that have been renamed.
99
116
op.walk ([&remappings](mlir::Operation *nestedOp) {
100
- llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
117
+ llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>>
118
+ symRefUpdates;
119
+ llvm::SmallVector<std::pair<mlir::StringAttr, mlir::ArrayAttr>>
120
+ arrayUpdates;
101
121
for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary ())
102
122
if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue ())) {
103
- if (auto remap = remappings.find (symRef.getLeafReference ());
104
- remap != remappings.end ()) {
105
- mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr (remap->second );
106
- if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
107
- symAttr = mlir::SymbolRefAttr::get (
108
- symRef.getRootReference (),
109
- {mlir::FlatSymbolRefAttr (remap->second )});
110
- updates.emplace_back (std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
111
- attr.getName (), symAttr});
123
+ if (auto newSymRef = processSymbolRef (symRef, nestedOp, remappings))
124
+ symRefUpdates.emplace_back (
125
+ std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{attr.getName (),
126
+ *newSymRef});
127
+ } else if (auto arrayAttr =
128
+ llvm::dyn_cast<mlir::ArrayAttr>(attr.getValue ())) {
129
+ llvm::SmallVector<mlir::Attribute> symbolRefs;
130
+ for (auto element : arrayAttr) {
131
+ if (!element) {
132
+ symbolRefs.push_back (element);
133
+ continue ;
134
+ }
135
+ auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(element);
136
+ std::optional<mlir::SymbolRefAttr> updatedSymRef;
137
+ if (symRef)
138
+ updatedSymRef = processSymbolRef (symRef, nestedOp, remappings);
139
+ if (!symRef || !updatedSymRef)
140
+ symbolRefs.push_back (element);
141
+ else
142
+ symbolRefs.push_back (*updatedSymRef);
112
143
}
144
+ arrayUpdates.push_back (std::make_pair (
145
+ attr.getName (),
146
+ mlir::ArrayAttr::get (nestedOp->getContext (), symbolRefs)));
113
147
}
114
- for (auto update : updates)
148
+ for (auto update : symRefUpdates)
149
+ nestedOp->setAttr (update.first , update.second );
150
+ for (auto update : arrayUpdates)
115
151
nestedOp->setAttr (update.first , update.second );
116
152
});
117
153
}
0 commit comments