@@ -41,6 +41,23 @@ mangleExternalName(const std::pair<fir::NameUniquer::NameKind,
4141 appendUnderscore);
4242}
4343
44+ // / Process a symbol reference and return the updated symbol reference if
45+ // / needed.
46+ std::optional<mlir::SymbolRefAttr>
47+ processSymbolRef (mlir::SymbolRefAttr symRef, mlir::Operation *nestedOp,
48+ const llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr>
49+ &remappings) {
50+ if (auto remap = remappings.find (symRef.getLeafReference ());
51+ remap != remappings.end ()) {
52+ mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr (remap->second );
53+ if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
54+ symAttr = mlir::SymbolRefAttr::get (
55+ symRef.getRootReference (), {mlir::FlatSymbolRefAttr (remap->second )});
56+ return symAttr;
57+ }
58+ return std::nullopt ;
59+ }
60+
4461namespace {
4562
4663class ExternalNameConversionPass
@@ -97,21 +114,40 @@ void ExternalNameConversionPass::runOnOperation() {
97114
98115 // Update all uses of the functions and globals that have been renamed.
99116 op.walk ([&remappings](mlir::Operation *nestedOp) {
100- llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
117+ llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>>
118+ symRefUpdates;
119+ llvm::SmallVector<std::pair<mlir::StringAttr, mlir::ArrayAttr>>
120+ arrayUpdates;
101121 for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary ())
102122 if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue ())) {
103- if (auto remap = remappings.find (symRef.getLeafReference ());
104- remap != remappings.end ()) {
105- mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr (remap->second );
106- if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
107- symAttr = mlir::SymbolRefAttr::get (
108- symRef.getRootReference (),
109- {mlir::FlatSymbolRefAttr (remap->second )});
110- updates.emplace_back (std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
111- attr.getName (), symAttr});
123+ if (auto newSymRef = processSymbolRef (symRef, nestedOp, remappings))
124+ symRefUpdates.emplace_back (
125+ std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{attr.getName (),
126+ *newSymRef});
127+ } else if (auto arrayAttr =
128+ llvm::dyn_cast<mlir::ArrayAttr>(attr.getValue ())) {
129+ llvm::SmallVector<mlir::Attribute> symbolRefs;
130+ for (auto element : arrayAttr) {
131+ if (!element) {
132+ symbolRefs.push_back (element);
133+ continue ;
134+ }
135+ auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(element);
136+ std::optional<mlir::SymbolRefAttr> updatedSymRef;
137+ if (symRef)
138+ updatedSymRef = processSymbolRef (symRef, nestedOp, remappings);
139+ if (!symRef || !updatedSymRef)
140+ symbolRefs.push_back (element);
141+ else
142+ symbolRefs.push_back (*updatedSymRef);
112143 }
144+ arrayUpdates.push_back (std::make_pair (
145+ attr.getName (),
146+ mlir::ArrayAttr::get (nestedOp->getContext (), symbolRefs)));
113147 }
114- for (auto update : updates)
148+ for (auto update : symRefUpdates)
149+ nestedOp->setAttr (update.first , update.second );
150+ for (auto update : arrayUpdates)
115151 nestedOp->setAttr (update.first , update.second );
116152 });
117153}
0 commit comments