From f3652719d1fc37bc1b46b5756c3c9b1842a179d9 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Tue, 3 Dec 2024 21:51:23 -0800 Subject: [PATCH] [flang][cuda] Run target rewrite in gpu.module --- flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 6 ++++++ flang/test/Fir/CUDA/cuda-target-rewrite.mlir | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 flang/test/Fir/CUDA/cuda-target-rewrite.mlir diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index ae6e7ce798d99..1b86d5241704b 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -27,6 +27,7 @@ #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Support/DataLayout.h" #include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" @@ -720,6 +721,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { convertSignature(fn); } + + for (auto gpuMod : mod.getOps()) + for (auto fn : gpuMod.getOps()) + convertSignature(fn); + return mlir::success(); } diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir new file mode 100644 index 0000000000000..5ba41b0e8afbe --- /dev/null +++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir @@ -0,0 +1,16 @@ +// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s + +gpu.module @testmod { + gpu.func @_QPvcpowdk(%arg0: !fir.ref> {cuf.data_attr = #cuf.cuda, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc} { + %0 = fir.alloca i64 + %1 = fir.load %0 : !fir.ref + %2 = fir.load %arg0 : !fir.ref> + %3 = fir.call @_FortranAzpowk(%2, %1) fastmath : (complex, i64) -> complex + gpu.return + } + func.func private @_FortranAzpowk(complex, i64) -> complex attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime} +} + +// CHECK-LABEL: gpu.func @_QPvcpowdk +// CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple +// CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}