diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td index f55f3e8a4466d..ccf9969e73a8e 100644 --- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td +++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td @@ -200,7 +200,7 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface, let arguments = (ins SymbolRefAttr:$callee, I32:$grid_x, I32:$grid_y, I32:$grid_z, I32:$block_x, I32:$block_y, I32:$block_z, - Optional:$bytes, Optional:$stream, + Optional:$bytes, Optional:$stream, Variadic:$args, OptionalAttr:$arg_attrs, OptionalAttr:$res_attrs); @@ -237,6 +237,8 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface, *this, getNbNoArgOperand(), getArgs().size() - 1); } }]; + + let hasVerifier = 1; } def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments, diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index 31f2650917781..f28778ce6c1c9 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -589,7 +589,7 @@ Fortran::lower::genCallOpAndResult( mlir::Value stream; // stream is optional. if (caller.getCallDescription().chevrons().size() > 3) - stream = fir::getBase(converter.genExprValue( + stream = fir::getBase(converter.genExprAddr( caller.getCallDescription().chevrons()[3], stmtCtx)); builder.create( diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp index ce197d48d4860..2c6d22f6f6c7d 100644 --- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp +++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp @@ -139,6 +139,24 @@ llvm::LogicalResult cuf::DeallocateOp::verify() { return mlir::success(); } +//===----------------------------------------------------------------------===// +// KernelLaunchOp +//===----------------------------------------------------------------------===// + +template +static llvm::LogicalResult checkStreamType(OpTy op) { + if (!op.getStream()) + return mlir::success(); + auto refTy = mlir::dyn_cast(op.getStream().getType()); + if (!refTy.getEleTy().isInteger(64)) + return op.emitOpError("stream is expected to be a i64 reference"); + return mlir::success(); +} + +llvm::LogicalResult cuf::KernelLaunchOp::verify() { + return checkStreamType(*this); +} + //===----------------------------------------------------------------------===// // KernelOp //===----------------------------------------------------------------------===// @@ -324,10 +342,7 @@ void cuf::SharedMemoryOp::build( //===----------------------------------------------------------------------===// llvm::LogicalResult cuf::StreamCastOp::verify() { - auto refTy = mlir::dyn_cast(getStream().getType()); - if (!refTy.getEleTy().isInteger(64)) - return emitOpError("stream is expected to be a i64 reference"); - return mlir::success(); + return checkStreamType(*this); } // Tablegen operators diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index caa59c6c17d0f..77364cb837c3c 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -879,8 +879,13 @@ struct CUFLaunchOpConversion gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY); gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ); } - if (op.getStream()) - gpuLaunchOp.getAsyncObjectMutable().assign(op.getStream()); + if (op.getStream()) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(gpuLaunchOp); + mlir::Value stream = + rewriter.create(loc, op.getStream()); + gpuLaunchOp.getAsyncDependenciesMutable().append(stream); + } if (procAttr) gpuLaunchOp->setAttr(cuf::getProcAttrName(), procAttr); rewriter.replaceOp(op, gpuLaunchOp); @@ -933,6 +938,7 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase { /*forceUnifiedTBAATree=*/false, *dl); target.addLegalDialect(); + target.addLegalOp(); cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab, patterns); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir index 621772efff415..319991546d3fe 100644 --- a/flang/test/Fir/CUDA/cuda-launch.fir +++ b/flang/test/Fir/CUDA/cuda-launch.fir @@ -146,8 +146,7 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e %1:2 = hlfir.declare %0 {uniq_name = "_QMtest_callFhostEstream"} : (!fir.ref) -> (!fir.ref, !fir.ref) %c1_i32 = arith.constant 1 : i32 %c0_i32 = arith.constant 0 : i32 - %2 = fir.load %1#0 : !fir.ref - cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c0_i32, %2 : i64>>>() + cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c0_i32, %1#0 : !fir.ref>>>() return } } @@ -155,5 +154,5 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e // CHECK-LABEL: func.func @_QQmain() // CHECK: %[[STREAM:.*]] = fir.alloca i64 {bindc_name = "stream", uniq_name = "_QMtest_callFhostEstream"} // CHECK: %[[DECL_STREAM:.*]]:2 = hlfir.declare %[[STREAM]] {uniq_name = "_QMtest_callFhostEstream"} : (!fir.ref) -> (!fir.ref, !fir.ref) -// CHECK: %[[STREAM_LOADED:.*]] = fir.load %[[DECL_STREAM]]#0 : !fir.ref -// CHECK: gpu.launch_func <%[[STREAM_LOADED]] : i64> @cuda_device_mod::@_QMdevptrPtest +// CHECK: %[[TOKEN:.*]] = cuf.stream_cast %[[DECL_STREAM]]#0 : +// CHECK: gpu.launch_func [%[[TOKEN]]] @cuda_device_mod::@_QMdevptrPtest diff --git a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf index d66d2811f7a8b..71e594e4742ec 100644 --- a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf +++ b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf @@ -45,8 +45,8 @@ contains call dev_kernel0<<<10, 20, 2>>>() ! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>() - call dev_kernel0<<<10, 20, 2, 0>>>() -! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>() + call dev_kernel0<<<10, 20, 2, 0_8>>>() +! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %{{.*}} : !fir.ref>>>() call dev_kernel1<<<1, 32>>>(a) ! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}}) : (!fir.ref) @@ -55,7 +55,7 @@ contains ! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}}) call dev_kernel1<<<*,32,0,stream>>>(a) -! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : i64>>>(%{{.*}}) : (!fir.ref) +! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : !fir.ref>>>(%{{.*}}) : (!fir.ref) end