Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
Expand Down
20 changes: 18 additions & 2 deletions flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
include "flang/Optimizer/Dialect/FIRTypes.td"
include "flang/Optimizer/Dialect/FIRAttr.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"

Expand Down Expand Up @@ -288,15 +289,30 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
let hasVerifier = 1;
}

def cuf_RegisterModuleOp : cuf_Op<"register_module", []> {
let summary = "Register a CUDA module";

let arguments = (ins
SymbolRefAttr:$name
);

let assemblyFormat = [{
$name attr-dict `->` type($modulePtr)
}];

let results = (outs LLVM_AnyPointer:$modulePtr);
}

def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
let summary = "Register a CUDA kernel";

let arguments = (ins
SymbolRefAttr:$name
SymbolRefAttr:$name,
LLVM_AnyPointer:$modulePtr
);

let assemblyFormat = [{
$name attr-dict
$name `(` $modulePtr `:` type($modulePtr) `)`attr-dict
}];

let hasVerifier = 1;
Expand Down
29 changes: 29 additions & 0 deletions flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//===- CUFToLLVMIRTranslation.h - CUF Dialect to LLVM IR --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for GPU dialect to LLVM IR translation.
//
//===----------------------------------------------------------------------===//

#ifndef FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
#define FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_

namespace mlir {
class DialectRegistry;
class MLIRContext;
} // namespace mlir

namespace cuf {

/// Register the CUF dialect and the translation from it to the LLVM IR in
/// the given registry.
void registerCUFDialectTranslation(mlir::DialectRegistry &registry);

} // namespace cuf

#endif // FLANG_OPTIMIZER_DIALECT_CUF_GPUTOLLVMIRTRANSLATION_H_
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Support/InitFIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H

#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "mlir/Conversion/Passes.h"
Expand Down Expand Up @@ -61,6 +62,7 @@ inline void addFIRExtensions(mlir::DialectRegistry &registry,
if (addFIRInlinerInterface)
addFIRInlinerExtension(registry);
addFIRToLLVMIRExtension(registry);
cuf::registerCUFDialectTranslation(registry);
}

inline void loadNonCodegenDialects(mlir::MLIRContext &context) {
Expand Down
28 changes: 28 additions & 0 deletions flang/include/flang/Runtime/CUDA/registration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===-- include/flang/Runtime/CUDA/registration.h ---------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
#define FORTRAN_RUNTIME_CUDA_REGISTRATION_H_

#include "flang/Runtime/entry-names.h"
#include <cstddef>

namespace Fortran::runtime::cuda {

extern "C" {

/// Register a CUDA module.
void *RTDECL(CUFRegisterModule)(void *data);

/// Register a device function.
void RTDECL(CUFRegisterFunction)(void **module, const char *fct);

} // extern "C"

} // namespace Fortran::runtime::cuda
#endif // FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_subdirectory(Attributes)
add_flang_library(CUFDialect
CUFDialect.cpp
CUFOps.cpp
CUFToLLVMIRTranslation.cpp

DEPENDS
MLIRIR
Expand Down
104 changes: 104 additions & 0 deletions flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR CUF dialect and LLVM IR.
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace {

LogicalResult registerModule(cuf::RegisterModuleOp op,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
std::string binaryIdentifier =
op.getName().getLeafReference().str() + "_bin_cst";
llvm::Module *module = moduleTranslation.getLLVMModule();
llvm::Value *binary = module->getGlobalVariable(binaryIdentifier, true);
if (!binary)
return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;

llvm::Type *ptrTy = builder.getPtrTy(0);
llvm::FunctionCallee fct = module->getOrInsertFunction(
RTNAME_STRING(CUFRegisterModule),
llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy}), false));
auto *handle = builder.CreateCall(fct, {binary});
moduleTranslation.mapValue(op->getResults().front()) = handle;
return mlir::success();
}

llvm::Value *getOrCreateFunctionName(llvm::Module *module,
llvm::IRBuilderBase &builder,
llvm::StringRef moduleName,
llvm::StringRef kernelName) {
std::string globalName =
std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName));

if (llvm::GlobalVariable *gv = module->getGlobalVariable(globalName))
return gv;

return builder.CreateGlobalString(kernelName, globalName);
}

LogicalResult registerKernel(cuf::RegisterKernelOp op,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Module *module = moduleTranslation.getLLVMModule();
llvm::Type *ptrTy = builder.getPtrTy(0);
llvm::FunctionCallee fct = module->getOrInsertFunction(
RTNAME_STRING(CUFRegisterFunction),
llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
false));
llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
builder.CreateCall(
fct, {modulePtr, getOrCreateFunctionName(module, builder,
op.getKernelModuleName().str(),
op.getKernelName().str())});
return mlir::success();
}

class CUFDialectLLVMIRTranslationInterface
: public LLVMTranslationDialectInterface {
public:
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

LogicalResult
convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const override {
return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
.Case([&](cuf::RegisterModuleOp op) {
return registerModule(op, builder, moduleTranslation);
})
.Case([&](cuf::RegisterKernelOp op) {
return registerKernel(op, builder, moduleTranslation);
})
.Default([&](Operation *op) {
return op->emitError("unsupported GPU operation: ") << op->getName();
});
}
};

} // namespace

void cuf::registerCUFDialectTranslation(DialectRegistry &registry) {
registry.insert<cuf::CUFDialect>();
registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) {
dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
});
}
5 changes: 4 additions & 1 deletion flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@ struct CUFAddConstructor
// Register kernels
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
if (gpuMod) {
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
auto registeredMod = builder.create<cuf::RegisterModuleOp>(
loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
if (func.isKernel()) {
auto kernelName = mlir::SymbolRefAttr::get(
builder.getStringAttr(cudaModName),
{mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
builder.create<cuf::RegisterKernelOp>(loc, kernelName);
builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CufOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/CUDA/memory.h"
#include "flang/Runtime/allocatable.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down
1 change: 1 addition & 0 deletions flang/runtime/CUDA/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_flang_library(${CUFRT_LIBNAME}
allocatable.cpp
descriptor.cpp
memory.cpp
registration.cpp
)

if (BUILD_SHARED_LIBS)
Expand Down
31 changes: 31 additions & 0 deletions flang/runtime/CUDA/registration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===-- runtime/CUDA/registration.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Runtime/CUDA/registration.h"

#include "cuda_runtime.h"

namespace Fortran::runtime::cuda {

extern "C" {

extern void **__cudaRegisterFatBinary(void *data);
extern void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun,
char *deviceFun, const char *deviceName, int thread_limit, uint3 *tid,
uint3 *bid, dim3 *bDim, dim3 *gDim, int *wSize);

void *RTDECL(CUFRegisterModule)(void *data) {
return __cudaRegisterFatBinary(data);
}

void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
__cudaRegisterFunction(module, fct, (char *)fct, fct, -1, (uint3 *)0,
(uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
}
}
} // namespace Fortran::runtime::cuda
5 changes: 3 additions & 2 deletions flang/test/Fir/CUDA/cuda-register-func.fir
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ module attributes {gpu.container_module} {
}

// CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2
// CHECK: %[[MOD_HANDLE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%[[MOD_HANDLE]] : !llvm.ptr)
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%[[MOD_HANDLE]] : !llvm.ptr)
15 changes: 10 additions & 5 deletions flang/test/Fir/cuf-invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op only kernel gpu.func can be registered}}
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
Expand All @@ -150,8 +151,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op device function not found}}
cuf.register_kernel @cuda_device_mod::@_QPsub_device2
cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%0 : !llvm.ptr)
llvm.return
}
}
Expand All @@ -160,8 +162,9 @@ module attributes {gpu.container_module} {

module attributes {gpu.container_module} {
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op gpu module not found}}
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
Expand All @@ -170,8 +173,9 @@ module attributes {gpu.container_module} {

module attributes {gpu.container_module} {
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op expect a module and a kernel name}}
cuf.register_kernel @_QPsub_device1
cuf.register_kernel @_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
Expand All @@ -185,8 +189,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
%0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}}
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
Loading