diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h index 4132db672e394..1edded090f8ce 100644 --- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h +++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h @@ -12,6 +12,7 @@ #include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" #include "flang/Optimizer/Dialect/CUF/CUFDialect.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/OpDefinition.h" #define GET_OP_CLASSES diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td index 98d1ef529738c..d34a8af0394a4 100644 --- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td +++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td @@ -18,6 +18,7 @@ include "flang/Optimizer/Dialect/CUF/CUFDialect.td" include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td" include "flang/Optimizer/Dialect/FIRTypes.td" include "flang/Optimizer/Dialect/FIRAttr.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/IR/BuiltinAttributes.td" @@ -288,15 +289,30 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments, let hasVerifier = 1; } +def cuf_RegisterModuleOp : cuf_Op<"register_module", []> { + let summary = "Register a CUDA module"; + + let arguments = (ins + SymbolRefAttr:$name + ); + + let assemblyFormat = [{ + $name attr-dict `->` type($modulePtr) + }]; + + let results = (outs LLVM_AnyPointer:$modulePtr); +} + def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> { let summary = "Register a CUDA kernel"; let arguments = (ins - SymbolRefAttr:$name + SymbolRefAttr:$name, + LLVM_AnyPointer:$modulePtr ); let assemblyFormat = [{ - $name attr-dict + $name `(` $modulePtr `:` type($modulePtr) `)`attr-dict }]; let hasVerifier = 1; diff --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp index 3db24226e7504..f260437e71041 100644 --- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp +++ b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp @@ -62,12 +62,15 @@ struct CUFAddConstructor // Register kernels auto gpuMod = symTab.lookup(cudaModName); if (gpuMod) { + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx); + auto registeredMod = builder.create( + loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName())); for (auto func : gpuMod.getOps()) { if (func.isKernel()) { auto kernelName = mlir::SymbolRefAttr::get( builder.getStringAttr(cudaModName), {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())}); - builder.create(loc, kernelName); + builder.create(loc, kernelName, registeredMod); } } } diff --git a/flang/test/Fir/CUDA/cuda-register-func.fir b/flang/test/Fir/CUDA/cuda-register-func.fir index 277475f0883dc..6b0cbfd3aca63 100644 --- a/flang/test/Fir/CUDA/cuda-register-func.fir +++ b/flang/test/Fir/CUDA/cuda-register-func.fir @@ -12,5 +12,6 @@ module attributes {gpu.container_module} { } // CHECK-LABEL: llvm.func internal @__cudaFortranConstructor() -// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1 -// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2 +// CHECK: %[[MOD_HANDLE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr +// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%[[MOD_HANDLE]] : !llvm.ptr) +// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%[[MOD_HANDLE]] : !llvm.ptr) diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir index 8a1eb48576832..a3b9be3ee8223 100644 --- a/flang/test/Fir/cuf-invalid.fir +++ b/flang/test/Fir/cuf-invalid.fir @@ -135,8 +135,9 @@ module attributes {gpu.container_module} { } } llvm.func internal @__cudaFortranConstructor() { + %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr // expected-error@+1{{'cuf.register_kernel' op only kernel gpu.func can be registered}} - cuf.register_kernel @cuda_device_mod::@_QPsub_device1 + cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr) llvm.return } } @@ -150,8 +151,9 @@ module attributes {gpu.container_module} { } } llvm.func internal @__cudaFortranConstructor() { + %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr // expected-error@+1{{'cuf.register_kernel' op device function not found}} - cuf.register_kernel @cuda_device_mod::@_QPsub_device2 + cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%0 : !llvm.ptr) llvm.return } } @@ -160,8 +162,9 @@ module attributes {gpu.container_module} { module attributes {gpu.container_module} { llvm.func internal @__cudaFortranConstructor() { + %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr // expected-error@+1{{'cuf.register_kernel' op gpu module not found}} - cuf.register_kernel @cuda_device_mod::@_QPsub_device1 + cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr) llvm.return } } @@ -170,8 +173,9 @@ module attributes {gpu.container_module} { module attributes {gpu.container_module} { llvm.func internal @__cudaFortranConstructor() { + %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr // expected-error@+1{{'cuf.register_kernel' op expect a module and a kernel name}} - cuf.register_kernel @_QPsub_device1 + cuf.register_kernel @_QPsub_device1(%0 : !llvm.ptr) llvm.return } } @@ -185,8 +189,9 @@ module attributes {gpu.container_module} { } } llvm.func internal @__cudaFortranConstructor() { + %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr // expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}} - cuf.register_kernel @cuda_device_mod::@_QPsub_device1 + cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr) llvm.return } }