25 changes: 25 additions & 0 deletions flang/include/flang/Optimizer/Transforms/CUFCommon.h
@@ -0,0 +1,25 @@
//===-- CUFCommon.h -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
#define FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinOps.h"

static constexpr llvm::StringRef cudaDeviceModuleName = "cuda_device_mod";

namespace cuf {

/// Retrieve or create the CUDA Fortran GPU module in the given \p mod.
mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
                                            mlir::SymbolTable &symTab);

} // namespace cuf

#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
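For context, a minimal usage sketch (not part of this patch) of how a CUF transform pass could call the new helper; the function name processDeviceKernels and its body are illustrative assumptions.

// Sketch only; assumes the usual flang/MLIR build environment.
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "llvm/Support/raw_ostream.h"

static void processDeviceKernels(mlir::ModuleOp mod) {
  mlir::SymbolTable symTab(mod);
  // Returns the existing "cuda_device_mod" or creates it with an NVVM target.
  mlir::gpu::GPUModuleOp gpuMod = cuf::getOrCreateGPUModule(mod, symTab);
  // All CUF passes share this one device module, so kernels emitted by one
  // pass are visible to the next (e.g. for registration in the constructor).
  for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
    if (func.isKernel())
      llvm::errs() << "device kernel: " << func.getName() << "\n";
}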
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
  CompilerGeneratedNames.cpp
  ConstantArgumentGlobalisation.cpp
  ControlFlowConverter.cpp
  CUFCommon.cpp
  CUFAddConstructor.cpp
  CUFDeviceGlobal.cpp
  CUFOpConversion.cpp
7 changes: 3 additions & 4 deletions flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
@@ -11,6 +11,7 @@
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -24,8 +25,6 @@ namespace fir {

namespace {

static constexpr llvm::StringRef cudaModName{"cuda_device_mod"};

static constexpr llvm::StringRef cudaFortranCtorName{
    "__cudaFortranConstructor"};

@@ -60,15 +59,15 @@ struct CUFAddConstructor
    builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);

    // Register kernels
    auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
    auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
    if (gpuMod) {
      auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
      auto registeredMod = builder.create<cuf::RegisterModuleOp>(
          loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
      for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
        if (func.isKernel()) {
          auto kernelName = mlir::SymbolRefAttr::get(
              builder.getStringAttr(cudaModName),
              builder.getStringAttr(cudaDeviceModuleName),
              {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
          builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
        }
31 changes: 31 additions & 0 deletions flang/lib/Optimizer/Transforms/CUFCommon.cpp
@@ -0,0 +1,31 @@
//===-- CUFCommon.cpp - Shared functions between passes ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

/// Retrieve or create the CUDA Fortran GPU module in the given \p mod.
mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule(mlir::ModuleOp mod,
                                                 mlir::SymbolTable &symTab) {
  if (auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName))
    return gpuMod;

  auto *ctx = mod.getContext();
  mod->setAttr(mlir::gpu::GPUDialect::getContainerModuleAttrName(),
               mlir::UnitAttr::get(ctx));

  mlir::OpBuilder builder(ctx);
  auto gpuMod = builder.create<mlir::gpu::GPUModuleOp>(mod.getLoc(),
                                                       cudaDeviceModuleName);
  llvm::SmallVector<mlir::Attribute> targets;
  targets.push_back(mlir::NVVM::NVVMTargetAttr::get(ctx));
  gpuMod.setTargetsAttr(builder.getArrayAttr(targets));
  mlir::Block::iterator insertPt(mod.getBodyRegion().front().end());
  symTab.insert(gpuMod, insertPt);
  return gpuMod;
}
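A hedged sketch of the contract this helper gives its callers: repeated calls hand back the one shared device module, because the lookup runs before anything is created and symTab.insert registers the new module in the symbol table. The checkIdempotence wrapper below is illustrative, not part of the patch.

// Illustrative only; assumes the usual flang/MLIR build environment.
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include <cassert>

static void checkIdempotence(mlir::ModuleOp mod) {
  mlir::SymbolTable symTab(mod);
  mlir::gpu::GPUModuleOp first = cuf::getOrCreateGPUModule(mod, symTab);
  // The second call takes the early-return lookup path and must yield the
  // same gpu.module operation.
  mlir::gpu::GPUModuleOp second = cuf::getOrCreateGPUModule(mod, symTab);
  assert(first.getOperation() == second.getOperation());
}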