Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def CufImplicitDeviceGlobal :
def CUFAddConstructor : Pass<"cuf-add-constructor", "mlir::ModuleOp"> {
let summary = "Add constructor to register CUDA Fortran allocators";
let dependentDialects = [
"mlir::func::FuncDialect"
"cuf::CUFDialect", "mlir::func::FuncDialect"
];
}

Expand Down
20 changes: 18 additions & 2 deletions flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallVector.h"
Expand All @@ -23,6 +24,8 @@ namespace fir {

namespace {

static constexpr llvm::StringRef cudaModName{"cuda_device_mod"};

static constexpr llvm::StringRef cudaFortranCtorName{
"__cudaFortranConstructor"};

Expand All @@ -31,6 +34,7 @@ struct CUFAddConstructor

void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
mlir::SymbolTable symTab(mod);
mlir::OpBuilder builder{mod.getBodyRegion()};
builder.setInsertionPointToEnd(mod.getBody());
mlir::Location loc = mod.getLoc();
Expand All @@ -48,13 +52,25 @@ struct CUFAddConstructor
mod.getContext(), RTNAME_STRING(CUFRegisterAllocator));
builder.setInsertionPointToEnd(mod.getBody());

// Create the constructor function that cal CUFRegisterAllocator.
builder.setInsertionPointToEnd(mod.getBody());
// Create the constructor function that call CUFRegisterAllocator.
auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName,
funcTy);
func.setLinkage(mlir::LLVM::Linkage::Internal);
builder.setInsertionPointToStart(func.addEntryBlock(builder));
builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);

// Register kernels
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
if (gpuMod) {
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
if (func.isKernel()) {
auto kernelName = mlir::SymbolRefAttr::get(
builder.getStringAttr(cudaModName),
{mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
builder.create<cuf::RegisterKernelOp>(loc, kernelName);
}
}
}
builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});

// Create the llvm.global_ctor with the function.
Expand Down
8 changes: 2 additions & 6 deletions flang/test/Fir/CUDA/cuda-register-func.fir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: fir-opt %s | FileCheck %s
// RUN: fir-opt --cuf-add-constructor %s | FileCheck %s

module attributes {gpu.container_module} {
gpu.module @cuda_device_mod {
Expand All @@ -9,12 +9,8 @@ module attributes {gpu.container_module} {
gpu.return
}
}
llvm.func internal @__cudaFortranConstructor() {
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
cuf.register_kernel @cuda_device_mod::@_QPsub_device2
llvm.return
}
}

// CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2
Loading