Skip to content

Commit 537359e

Browse files
committed
[flang][cuda] Update CompilerGeneratedNames pass to work on gpu module
1 parent 99c2e3b commit 537359e

File tree

3 files changed

+68
-30
lines changed

3 files changed

+68
-30
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,10 +1247,10 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
12471247

12481248
/// Get the address of the type descriptor global variable that was created by
12491249
/// lowering for derived type \p recType.
1250-
mlir::Value getTypeDescriptor(mlir::ModuleOp mod,
1251-
mlir::ConversionPatternRewriter &rewriter,
1252-
mlir::Location loc,
1253-
fir::RecordType recType) const {
1250+
template <typename ModOpTy>
1251+
mlir::Value
1252+
getTypeDescriptor(ModOpTy mod, mlir::ConversionPatternRewriter &rewriter,
1253+
mlir::Location loc, fir::RecordType recType) const {
12541254
std::string name =
12551255
this->options.typeDescriptorsRenamedForAssembly
12561256
? fir::NameUniquer::getTypeDescriptorAssemblyName(recType.getName())
@@ -1275,7 +1275,8 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
12751275
return rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPtrTy);
12761276
}
12771277

1278-
mlir::Value populateDescriptor(mlir::Location loc, mlir::ModuleOp mod,
1278+
template <typename ModOpTy>
1279+
mlir::Value populateDescriptor(mlir::Location loc, ModOpTy mod,
12791280
fir::BaseBoxType boxTy, mlir::Type inputType,
12801281
mlir::ConversionPatternRewriter &rewriter,
12811282
unsigned rank, mlir::Value eleSize,
@@ -1414,10 +1415,16 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
14141415
extraField =
14151416
this->getExtraFromBox(loc, sourceBoxTyPair, sourceBox, rewriter);
14161417
}
1417-
auto mod = box->template getParentOfType<mlir::ModuleOp>();
1418-
mlir::Value descriptor =
1419-
populateDescriptor(loc, mod, boxTy, inputType, rewriter, rank, eleSize,
1420-
cfiTy, typeDesc, allocatorIdx, extraField);
1418+
1419+
mlir::Value descriptor;
1420+
if (auto gpuMod = box->template getParentOfType<mlir::gpu::GPUModuleOp>())
1421+
descriptor = populateDescriptor(loc, gpuMod, boxTy, inputType, rewriter,
1422+
rank, eleSize, cfiTy, typeDesc,
1423+
allocatorIdx, extraField);
1424+
else if (auto mod = box->template getParentOfType<mlir::ModuleOp>())
1425+
descriptor = populateDescriptor(loc, mod, boxTy, inputType, rewriter,
1426+
rank, eleSize, cfiTy, typeDesc,
1427+
allocatorIdx, extraField);
14211428

14221429
return {boxTy, descriptor, eleSize};
14231430
}
@@ -1460,11 +1467,17 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
14601467
mlir::Value extraField =
14611468
this->getExtraFromBox(loc, inputBoxTyPair, loweredBox, rewriter);
14621469

1463-
auto mod = box->template getParentOfType<mlir::ModuleOp>();
1464-
mlir::Value descriptor =
1465-
populateDescriptor(loc, mod, boxTy, box.getBox().getType(), rewriter,
1466-
rank, eleSize, cfiTy, typeDesc,
1467-
/*allocatorIdx=*/kDefaultAllocator, extraField);
1470+
mlir::Value descriptor;
1471+
if (auto gpuMod = box->template getParentOfType<mlir::gpu::GPUModuleOp>())
1472+
descriptor =
1473+
populateDescriptor(loc, gpuMod, boxTy, box.getBox().getType(),
1474+
rewriter, rank, eleSize, cfiTy, typeDesc,
1475+
/*allocatorIdx=*/kDefaultAllocator, extraField);
1476+
else if (auto mod = box->template getParentOfType<mlir::ModuleOp>())
1477+
descriptor =
1478+
populateDescriptor(loc, mod, boxTy, box.getBox().getType(), rewriter,
1479+
rank, eleSize, cfiTy, typeDesc,
1480+
/*allocatorIdx=*/kDefaultAllocator, extraField);
14681481

14691482
return {boxTy, descriptor, eleSize};
14701483
}

flang/lib/Optimizer/Transforms/CompilerGeneratedNames.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
1212
#include "flang/Optimizer/Support/InternalNames.h"
1313
#include "flang/Optimizer/Transforms/Passes.h"
14+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1415
#include "mlir/IR/Attributes.h"
1516
#include "mlir/IR/SymbolTable.h"
1617
#include "mlir/Pass/Pass.h"
@@ -42,24 +43,31 @@ void CompilerGeneratedNamesConversionPass::runOnOperation() {
4243
auto *context = &getContext();
4344

4445
llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
45-
for (auto &funcOrGlobal : op->getRegion(0).front()) {
46-
if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal) ||
47-
llvm::isa<fir::GlobalOp>(funcOrGlobal)) {
48-
auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>(
49-
mlir::SymbolTable::getSymbolAttrName());
50-
auto deconstructedName = fir::NameUniquer::deconstruct(symName);
51-
if (deconstructedName.first != fir::NameUniquer::NameKind::NOT_UNIQUED &&
52-
!fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
53-
std::string newName =
54-
fir::NameUniquer::replaceSpecialSymbols(symName.getValue().str());
55-
if (newName != symName) {
56-
auto newAttr = mlir::StringAttr::get(context, newName);
57-
mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr);
58-
auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
59-
remappings.try_emplace(symName, newSymRef);
60-
}
46+
47+
auto processOp = [&](mlir::Operation &op) {
48+
auto symName = op.getAttrOfType<mlir::StringAttr>(
49+
mlir::SymbolTable::getSymbolAttrName());
50+
auto deconstructedName = fir::NameUniquer::deconstruct(symName);
51+
if (deconstructedName.first != fir::NameUniquer::NameKind::NOT_UNIQUED &&
52+
!fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
53+
std::string newName =
54+
fir::NameUniquer::replaceSpecialSymbols(symName.getValue().str());
55+
if (newName != symName) {
56+
auto newAttr = mlir::StringAttr::get(context, newName);
57+
mlir::SymbolTable::setSymbolName(&op, newAttr);
58+
auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
59+
remappings.try_emplace(symName, newSymRef);
6160
}
6261
}
62+
};
63+
for (auto &op : op->getRegion(0).front()) {
64+
if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op))
65+
processOp(op);
66+
else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(&op))
67+
for (auto &op : gpuMod->getRegion(0).front())
68+
if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op) ||
69+
llvm::isa<mlir::gpu::GPUFuncOp>(op))
70+
processOp(op);
6371
}
6472

6573
if (remappings.empty())
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: fir-opt --split-input-file --compiler-generated-names --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu type-descriptors-renamed-for-assembly=true" %s | FileCheck %s
2+
3+
module @mod1 attributes {gpu.container} {
4+
gpu.module @gpu1 {
5+
fir.global linkonce @_QMtest_dinitE.dt.tseq constant : i8
6+
7+
func.func @embox1(%arg0: !fir.ref<!fir.type<_QMtest_dinitTtseq{i:i32}>>) {
8+
%0 = fir.embox %arg0() : (!fir.ref<!fir.type<_QMtest_dinitTtseq{i:i32}>>) -> !fir.box<!fir.type<_QMtest_dinitTtseq{i:i32}>>
9+
return
10+
}
11+
}
12+
}
13+
14+
// CHECK-LABEL: gpu.module @gpu1
15+
// CHECK: llvm.mlir.global linkonce constant @_QMtest_dinitEXdtXtseq
16+
// CHECK: llvm.mlir.addressof @_QMtest_dinitEXdtXtseq : !llvm.ptr
17+

0 commit comments

Comments
 (0)