Skip to content

Commit 9a22f55

Browse files
committed
[flang][cuda] Add interfaces and lowering for barrier_try_wait(_sleep)
1 parent 83f751a commit 9a22f55

File tree

4 files changed

+105
-6
lines changed

4 files changed

+105
-6
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ struct IntrinsicLibrary {
211211
mlir::Value genBarrierArrive(mlir::Type, llvm::ArrayRef<mlir::Value>);
212212
mlir::Value genBarrierArriveCnt(mlir::Type, llvm::ArrayRef<mlir::Value>);
213213
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
214+
mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef<mlir::Value>);
215+
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
214216
fir::ExtendedValue genBesselJn(mlir::Type,
215217
llvm::ArrayRef<fir::ExtendedValue>);
216218
fir::ExtendedValue genBesselYn(mlir::Type,

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5151
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
5252
#include "mlir/Dialect/Math/IR/Math.h"
53+
#include "mlir/Dialect/SCF/IR/SCF.h"
5354
#include "mlir/Dialect/Vector/IR/VectorOps.h"
5455
#include "llvm/Support/CommandLine.h"
5556
#include "llvm/Support/Debug.h"
@@ -358,6 +359,14 @@ static constexpr IntrinsicHandler handlers[]{
358359
&I::genBarrierInit,
359360
{{{"barrier", asAddr}, {"count", asValue}}},
360361
/*isElemental=*/false},
362+
{"barrier_try_wait",
363+
&I::genBarrierTryWait,
364+
{{{"barrier", asAddr}, {"token", asValue}}},
365+
/*isElemental=*/false},
366+
{"barrier_try_wait_sleep",
367+
&I::genBarrierTryWaitSleep,
368+
{{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}},
369+
/*isElemental=*/false},
361370
{"bessel_jn",
362371
&I::genBesselJn,
363372
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -3280,6 +3289,57 @@ void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
32803289
mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);
32813290
}
32823291

3292+
// BARRIER_TRY_WAIT (CUDA)
3293+
mlir::Value
3294+
IntrinsicLibrary::genBarrierTryWait(mlir::Type resultType,
3295+
llvm::ArrayRef<mlir::Value> args) {
3296+
assert(args.size() == 2);
3297+
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
3298+
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
3299+
fir::StoreOp::create(builder, loc, zero, res);
3300+
mlir::Value ns =
3301+
builder.createIntegerConstant(loc, builder.getI32Type(), 1000000);
3302+
mlir::Value load = fir::LoadOp::create(builder, loc, res);
3303+
auto whileOp = mlir::scf::WhileOp::create(
3304+
builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{load});
3305+
mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore());
3306+
mlir::Value beforeArg = beforeBlock->addArgument(resultType, loc);
3307+
builder.setInsertionPointToStart(beforeBlock);
3308+
mlir::Value condition = mlir::arith::CmpIOp::create(
3309+
builder, loc, mlir::arith::CmpIPredicate::ne, beforeArg, zero);
3310+
mlir::scf::ConditionOp::create(builder, loc, condition, beforeArg);
3311+
mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter());
3312+
mlir::Value afterArg = afterBlock->addArgument(resultType, loc);
3313+
builder.setInsertionPointToStart(afterBlock);
3314+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
3315+
auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
3316+
mlir::Value ret =
3317+
mlir::NVVM::InlinePtxOp::create(
3318+
builder, loc, {resultType}, {barrier, args[1], ns}, {},
3319+
".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
3320+
"selp.b32 %0, 1, 0, p;",
3321+
{})
3322+
.getResult(0);
3323+
mlir::scf::YieldOp::create(builder, loc, ret);
3324+
builder.setInsertionPointAfter(whileOp);
3325+
return whileOp.getResult(0);
3326+
}
3327+
3328+
// BARRIER_TRY_WAIT_SLEEP (CUDA)
3329+
mlir::Value
3330+
IntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
3331+
llvm::ArrayRef<mlir::Value> args) {
3332+
assert(args.size() == 3);
3333+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
3334+
auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
3335+
return mlir::NVVM::InlinePtxOp::create(
3336+
builder, loc, {resultType}, {barrier, args[1], args[2]}, {},
3337+
".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; "
3338+
"selp.b32 %0, 1, 0, p;",
3339+
{})
3340+
.getResult(0);
3341+
}
3342+
32833343
// BESSEL_JN
32843344
fir::ExtendedValue
32853345
IntrinsicLibrary::genBesselJn(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1998,22 +1998,37 @@ attributes(device,host) logical function on_device() bind(c)
19981998

19991999
! TMA Operations
20002000

2001+
interface barrier_arrive
2002+
attributes(device) function barrier_arrive(barrier) result(token)
2003+
integer(8), shared :: barrier
2004+
integer(8) :: token
2005+
end function
2006+
attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
2007+
integer(8), shared :: barrier
2008+
integer(4), value :: count
2009+
integer(8) :: token
2010+
end function
2011+
end interface
2012+
20012013
interface
20022014
attributes(device) subroutine barrier_init(barrier, count)
20032015
integer(8), shared :: barrier
20042016
integer(4), value :: count
20052017
end subroutine
20062018
end interface
20072019

2008-
interface barrier_arrive
2009-
attributes(device) function barrier_arrive(barrier) result(token)
2020+
interface
2021+
attributes(device) integer function barrier_try_wait(barrier, token)
20102022
integer(8), shared :: barrier
2011-
integer(8) :: token
2023+
integer(8), value :: token
20122024
end function
2013-
attributes(device) function barrier_arrive_cnt(barrier, count) result(token)
2025+
end interface
2026+
2027+
interface
2028+
attributes(device) integer function barrier_try_wait_sleep(barrier, token, ns)
20142029
integer(8), shared :: barrier
2015-
integer(4), value :: count
2016-
integer(8) :: token
2030+
integer(8), value :: token
2031+
integer(4), value :: ns
20172032
end function
20182033
end interface
20192034

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,25 @@ end subroutine
479479

480480
! CHECK-LABEL: func.func @_QPtest_bulk_s2g
481481
! CHECL: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
482+
483+
attributes(global) subroutine test_barrier_try_wait()
484+
integer :: istat
485+
integer(8), shared :: barrier1
486+
integer(8) :: token
487+
istat = barrier_try_wait(barrier1, token)
488+
end subroutine
489+
490+
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait()
491+
! CHECK: scf.while
492+
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %{{.*}}, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %c1000000{{.*}} : !llvm.ptr, i64, i32) -> i32
493+
494+
attributes(global) subroutine test_barrier_try_wait_sleep()
495+
integer :: istat
496+
integer(8), shared :: barrier1
497+
integer(8) :: token
498+
integer(4) :: sleep_time
499+
istat = barrier_try_wait_sleep(barrier1, token, sleep_time)
500+
end subroutine
501+
502+
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
503+
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32

0 commit comments

Comments
 (0)