Skip to content

Commit 3acf856

Browse files
authored
Adding CUFCommon.{h,cpp} for CUF utilities (#113740)
1 parent 242ccd2 commit 3acf856

File tree

4 files changed

+60
-4
lines changed

4 files changed

+60
-4
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===-- CUFCommon.h -------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
10+
#define FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
11+
12+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
13+
#include "mlir/IR/BuiltinOps.h"
14+
15+
/// Symbol name of the CUDA Fortran GPU module that cuf::getOrCreateGPUModule
/// looks up and creates. Declared `inline` so that every translation unit
/// including this header shares one definition instead of getting its own
/// internal-linkage copy.
inline constexpr llvm::StringRef cudaDeviceModuleName = "cuda_device_mod";
16+
17+
namespace cuf {
18+
19+
/// Retrieve or create the CUDA Fortran GPU module in the given \p mod.
///
/// The module is looked up in \p symTab under `cudaDeviceModuleName`. When it
/// is not found, a new mlir::gpu::GPUModuleOp with that name is created,
/// inserted through \p symTab at the end of \p mod's body, and returned.
/// NOTE(review): \p symTab is presumably the symbol table built for \p mod;
/// confirm at the call sites.
20+
mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
21+
mlir::SymbolTable &symTab);
22+
23+
} // namespace cuf
24+
25+
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
99
CompilerGeneratedNames.cpp
1010
ConstantArgumentGlobalisation.cpp
1111
ControlFlowConverter.cpp
12+
CUFCommon.cpp
1213
CUFAddConstructor.cpp
1314
CUFDeviceGlobal.cpp
1415
CUFOpConversion.cpp

flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "flang/Optimizer/Dialect/FIRAttr.h"
1212
#include "flang/Optimizer/Dialect/FIRDialect.h"
1313
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
14+
#include "flang/Optimizer/Transforms/CUFCommon.h"
1415
#include "flang/Runtime/entry-names.h"
1516
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1617
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -24,8 +25,6 @@ namespace fir {
2425

2526
namespace {
2627

27-
static constexpr llvm::StringRef cudaModName{"cuda_device_mod"};
28-
2928
static constexpr llvm::StringRef cudaFortranCtorName{
3029
"__cudaFortranConstructor"};
3130

@@ -60,15 +59,15 @@ struct CUFAddConstructor
6059
builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
6160

6261
// Register kernels
63-
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
62+
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
6463
if (gpuMod) {
6564
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
6665
auto registeredMod = builder.create<cuf::RegisterModuleOp>(
6766
loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
6867
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
6968
if (func.isKernel()) {
7069
auto kernelName = mlir::SymbolRefAttr::get(
71-
builder.getStringAttr(cudaModName),
70+
builder.getStringAttr(cudaDeviceModuleName),
7271
{mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
7372
builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
7473
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===-- CUFCommon.cpp - Shared functions between passes ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Transforms/CUFCommon.h"
10+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
11+
12+
/// Retrieve or create the CUDA Fortran GPU module in the given \p mod.
///
/// Returns the mlir::gpu::GPUModuleOp registered in \p symTab under
/// `cudaDeviceModuleName` when one exists. Otherwise sets the GPU dialect's
/// container-module unit attribute on \p mod, creates a GPUModuleOp with that
/// name and an NVVM target attribute, inserts it through \p symTab at the end
/// of the first block of \p mod's body region, and returns the new op.
13+
mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule(mlir::ModuleOp mod,
14+
mlir::SymbolTable &symTab) {
15+
// Reuse the module if the symbol table already holds one with this name.
if (auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName))
16+
return gpuMod;
17+
18+
auto *ctx = mod.getContext();
19+
// Tag the enclosing module with the GPU container-module unit attribute.
mod->setAttr(mlir::gpu::GPUDialect::getContainerModuleAttrName(),
20+
mlir::UnitAttr::get(ctx));
21+
22+
// No insertion point is set on this builder in this function; the new op is
// placed by the symTab.insert call below.
mlir::OpBuilder builder(ctx);
23+
auto gpuMod = builder.create<mlir::gpu::GPUModuleOp>(mod.getLoc(),
24+
cudaDeviceModuleName);
25+
// The targets attribute is an array holding a single NVVM target attribute
// built from the context alone.
llvm::SmallVector<mlir::Attribute> targets;
26+
targets.push_back(mlir::NVVM::NVVMTargetAttr::get(ctx));
27+
gpuMod.setTargetsAttr(builder.getArrayAttr(targets));
28+
// Insert at the end of the first block of the module body, going through the
// symbol table rather than the builder.
mlir::Block::iterator insertPt(mod.getBodyRegion().front().end());
29+
symTab.insert(gpuMod, insertPt);
30+
return gpuMod;
31+
}

0 commit comments

Comments
 (0)