From 764aebc93d6ebb1bcd013855edbe77c472a48504 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 6 Feb 2025 16:45:31 -0800 Subject: [PATCH] [flang][cuda] Lower syncwarp to NVVM intrinsic --- .../include/flang/Optimizer/Builder/IntrinsicCall.h | 1 + flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 13 +++++++++++++ flang/module/cudadevice.f90 | 2 +- flang/test/Lower/CUDA/cuda-device-proc.cuf | 6 +++--- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h index 32010ae83641e..47e8a77fa6aec 100644 --- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h +++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h @@ -406,6 +406,7 @@ struct IntrinsicLibrary { mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef); mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef); mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef); + void genSyncWarp(llvm::ArrayRef); fir::ExtendedValue genSystem(std::optional, mlir::ArrayRef args); void genSystemClock(llvm::ArrayRef); diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index a6a77dd58677b..9b684520ec078 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -680,6 +680,7 @@ static constexpr IntrinsicHandler handlers[]{ {"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false}, {"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false}, {"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false}, + {"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false}, {"system", &I::genSystem, {{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}}, @@ -7704,6 +7705,18 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType, return builder.create(loc, funcOp, args).getResult(0); } +// SYNCWARP +void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef 
args) { + assert(args.size() == 1); + constexpr llvm::StringLiteral funcName = "llvm.nvvm.bar.warp.sync"; + mlir::Value mask = fir::getBase(args[0]); + mlir::FunctionType funcType = + mlir::FunctionType::get(builder.getContext(), {mask.getType()}, {}); + auto funcOp = builder.createFunction(loc, funcName, funcType); + llvm::SmallVector argsList{mask}; + builder.create(loc, funcOp, argsList); +} + // SYSTEM fir::ExtendedValue IntrinsicLibrary::genSystem(std::optional resultType, diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90 index 47526bccd98fe..45b9f2c838638 100644 --- a/flang/module/cudadevice.f90 +++ b/flang/module/cudadevice.f90 @@ -49,7 +49,7 @@ attributes(device) integer function syncthreads_or(value) public :: syncthreads_or interface - attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp') + attributes(device) subroutine syncwarp(mask) integer, value :: mask end subroutine end interface diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index ec825263474c1..17a6a1d965640 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -47,7 +47,7 @@ end ! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc} ! CHECK: fir.call @llvm.nvvm.barrier0() fastmath : () -> () -! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs fastmath : (i32) -> () +! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath : (i32) -> () ! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath : () -> () ! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath : () -> () ! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath : () -> () @@ -102,13 +102,13 @@ end ! CHECK-LABEL: func.func @_QPhost1() ! CHECK: cuf.kernel ! CHECK: fir.call @llvm.nvvm.barrier0() fastmath : () -> () -! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs fastmath : (i32) -> () +! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath : (i32) -> () ! 
CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath : (i32) -> i32 ! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath : (i32) -> i32 ! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath : (i32) -> i32 ! CHECK: func.func private @llvm.nvvm.barrier0() -! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs} +! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32) ! CHECK: func.func private @llvm.nvvm.membar.gl() ! CHECK: func.func private @llvm.nvvm.membar.cta() ! CHECK: func.func private @llvm.nvvm.membar.sys()