Skip to content

Commit 1763752

Browse files
clementvalDebadri Basak
authored andcommitted
[flang][cuda] Add interfaces and lowering for barrier_try_wait(_sleep) (llvm#165316)
As described in the programming guide: https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/#load-and-store-functions-using-bulk-tma-operations
1 parent 9ce6628 commit 1763752

File tree

4 files changed

+105
-6
lines changed

4 files changed

+105
-6
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ struct IntrinsicLibrary {
211211
mlir::Value genBarrierArrive(mlir::Type, llvm::ArrayRef<mlir::Value>);
212212
mlir::Value genBarrierArriveCnt(mlir::Type, llvm::ArrayRef<mlir::Value>);
213213
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
214+
mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef<mlir::Value>);
215+
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
214216
fir::ExtendedValue genBesselJn(mlir::Type,
215217
llvm::ArrayRef<fir::ExtendedValue>);
216218
fir::ExtendedValue genBesselYn(mlir::Type,

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5151
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
5252
#include "mlir/Dialect/Math/IR/Math.h"
53+
#include "mlir/Dialect/SCF/IR/SCF.h"
5354
#include "mlir/Dialect/Vector/IR/VectorOps.h"
5455
#include "llvm/Support/CommandLine.h"
5556
#include "llvm/Support/Debug.h"
@@ -358,6 +359,14 @@ static constexpr IntrinsicHandler handlers[]{
358359
&I::genBarrierInit,
359360
{{{"barrier", asAddr}, {"count", asValue}}},
360361
/*isElemental=*/false},
362+
{"barrier_try_wait",
363+
&I::genBarrierTryWait,
364+
{{{"barrier", asAddr}, {"token", asValue}}},
365+
/*isElemental=*/false},
366+
{"barrier_try_wait_sleep",
367+
&I::genBarrierTryWaitSleep,
368+
{{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}},
369+
/*isElemental=*/false},
361370
{"bessel_jn",
362371
&I::genBesselJn,
363372
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -3282,6 +3291,57 @@ void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
32823291
mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);
32833292
}
32843293

3294+
// BARRIER_TRY_WAIT (CUDA)
3295+
mlir::Value
3296+
IntrinsicLibrary::genBarrierTryWait(mlir::Type resultType,
3297+
llvm::ArrayRef<mlir::Value> args) {
3298+
assert(args.size() == 2);
3299+
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
3300+
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
3301+
fir::StoreOp::create(builder, loc, zero, res);
3302+
mlir::Value ns =
3303+
builder.createIntegerConstant(loc, builder.getI32Type(), 1000000);
3304+
mlir::Value load = fir::LoadOp::create(builder, loc, res);
3305+
auto whileOp = mlir::scf::WhileOp::create(
3306+
builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{load});
3307+
mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore());
3308+
mlir::Value beforeArg = beforeBlock->addArgument(resultType, loc);
3309+
builder.setInsertionPointToStart(beforeBlock);
3310+
mlir::Value condition = mlir::arith::CmpIOp::create(
3311+
builder, loc, mlir::arith::CmpIPredicate::ne, beforeArg, zero);
3312+
mlir::scf::ConditionOp::create(builder, loc, condition, beforeArg);
3313+
mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter());
3314+
afterBlock->addArgument(resultType, loc);
3315+
builder.setInsertionPointToStart(afterBlock);
3316+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
3317+
auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
3318+
mlir::Value ret =
3319+
mlir::NVVM::InlinePtxOp::create(
3320+
builder, loc, {resultType}, {barrier, args[1], ns}, {},
3321+
".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
3322+
"selp.b32 %0, 1, 0, p;",
3323+
{})
3324+
.getResult(0);
3325+
mlir::scf::YieldOp::create(builder, loc, ret);
3326+
builder.setInsertionPointAfter(whileOp);
3327+
return whileOp.getResult(0);
3328+
}
3329+
3330+
// BARRIER_TRY_WAIT_SLEEP (CUDA)
3331+
mlir::Value
3332+
IntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
3333+
llvm::ArrayRef<mlir::Value> args) {
3334+
assert(args.size() == 3);
3335+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
3336+
auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
3337+
return mlir::NVVM::InlinePtxOp::create(
3338+
builder, loc, {resultType}, {barrier, args[1], args[2]}, {},
3339+
".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
3340+
"selp.b32 %0, 1, 0, p;",
3341+
{})
3342+
.getResult(0);
3343+
}
3344+
32853345
// BESSEL_JN
32863346
fir::ExtendedValue
32873347
IntrinsicLibrary::genBesselJn(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1998,22 +1998,37 @@ attributes(device,host) logical function on_device() bind(c)
19981998

19991999
! TMA Operations
20002000

2001+
interface barrier_arrive
2002+
attributes(device) function barrier_arrive(barrier) result(token)
2003+
integer(8), shared :: barrier
2004+
integer(8) :: token
2005+
end function
2006+
attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
2007+
integer(8), shared :: barrier
2008+
integer(4), value :: count
2009+
integer(8) :: token
2010+
end function
2011+
end interface
2012+
20012013
interface
20022014
attributes(device) subroutine barrier_init(barrier, count)
20032015
integer(8), shared :: barrier
20042016
integer(4), value :: count
20052017
end subroutine
20062018
end interface
20072019

2008-
interface barrier_arrive
2009-
attributes(device) function barrier_arrive(barrier) result(token)
2020+
interface
2021+
attributes(device) integer function barrier_try_wait(barrier, token)
20102022
integer(8), shared :: barrier
2011-
integer(8) :: token
2023+
integer(8), value :: token
20122024
end function
2013-
attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
2025+
end interface
2026+
2027+
interface
2028+
attributes(device) integer function barrier_try_wait_sleep(barrier, token, ns)
20142029
integer(8), shared :: barrier
2015-
integer(4), value :: count
2016-
integer(8) :: token
2030+
integer(8), value :: token
2031+
integer(4), value :: ns
20172032
end function
20182033
end interface
20192034

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,25 @@ end subroutine
492492
! CHECK: %[[CASTED_CMP_XCHG_EV:.*]] = fir.convert %[[CMP_XCHG_EV]] : (i1) -> i32
493493
! CHECK: %{{.*}} = arith.constant 1 : i32
494494
! CHECK: %19 = arith.cmpi eq, %[[CASTED_CMP_XCHG_EV]], %{{.*}} : i32
495+
496+
attributes(global) subroutine test_barrier_try_wait()
497+
integer :: istat
498+
integer(8), shared :: barrier1
499+
integer(8) :: token
500+
istat = barrier_try_wait(barrier1, token)
501+
end subroutine
502+
503+
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait()
504+
! CHECK: scf.while
505+
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %{{.*}}, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %c1000000{{.*}} : !llvm.ptr, i64, i32) -> i32
506+
507+
attributes(global) subroutine test_barrier_try_wait_sleep()
508+
integer :: istat
509+
integer(8), shared :: barrier1
510+
integer(8) :: token
511+
integer(4) :: sleep_time
512+
istat = barrier_try_wait_sleep(barrier1, token, sleep_time)
513+
end subroutine
514+
515+
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
516+
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32

0 commit comments

Comments
 (0)