Skip to content

Commit 893e400

Browse files
clementval authored and DharuniRAcharya committed
[flang][cuda] Add interface and lower barrier_init (llvm#162929)
Add interface for `barrier_init` and lower it to the NVVM Op. https://github.com/clementval/llvm-project/pull/new/cuf_barrier_init
1 parent b9b452d commit 893e400

File tree

4 files changed

+42
-25
lines changed

4 files changed

+42
-25
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ struct IntrinsicLibrary {
208208
fir::ExtendedValue genAssociated(mlir::Type,
209209
llvm::ArrayRef<fir::ExtendedValue>);
210210
mlir::Value genAtand(mlir::Type, llvm::ArrayRef<mlir::Value>);
211+
void genBarrierInit(llvm::ArrayRef<fir::ExtendedValue>);
211212
fir::ExtendedValue genBesselJn(mlir::Type,
212213
llvm::ArrayRef<fir::ExtendedValue>);
213214
fir::ExtendedValue genBesselYn(mlir::Type,

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,10 @@ static constexpr IntrinsicHandler handlers[]{
346346
&I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>,
347347
{{{"mask", asValue}, {"pred", asValue}}},
348348
/*isElemental=*/false},
349+
{"barrier_init",
350+
&I::genBarrierInit,
351+
{{{"barrier", asAddr}, {"count", asValue}}},
352+
/*isElemental=*/false},
349353
{"bessel_jn",
350354
&I::genBesselJn,
351355
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -3176,6 +3180,22 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType,
31763180
return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox);
31773181
}
31783182

3183+
// BARRIER_INIT (CUDA)
// Lowers a call to `barrier_init(barrier, count)` to the NVVM mbarrier-init
// operation. Exactly two arguments are expected (asserted below):
//   args[0] - the barrier, registered as `asAddr` in the handler table;
//   args[1] - the count, registered as `asValue` in the handler table.
// The base of args[0] is first converted to an LLVM pointer with fir.convert,
// then address-space-cast to a pointer in the NVVM shared memory space, and
// that pointer is passed together with the base of args[1] to
// MBarrierInitSharedOp. No value is produced (subroutine-style intrinsic).
3184+
void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
3185+
assert(args.size() == 2);
3186+
auto llvmPtr = fir::ConvertOp::create(
3187+
builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
3188+
fir::getBase(args[0]));
3189+
auto addrCast = mlir::LLVM::AddrSpaceCastOp::create(
3190+
builder, loc,
3191+
mlir::LLVM::LLVMPointerType::get(
3192+
builder.getContext(),
3193+
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared)),
3194+
llvmPtr);
3195+
mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, addrCast,
3196+
fir::getBase(args[1]), {});
3197+
}
3198+
31793199
// BESSEL_JN
31803200
fir::ExtendedValue
31813201
IntrinsicLibrary::genBesselJn(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,6 +1987,13 @@ attributes(device,host) logical function on_device() bind(c)
19871987
end function
19881988
end interface
19891989

1990+
interface
1991+
attributes(device) subroutine barrier_init(barrier, count)
1992+
integer(8) :: barrier
1993+
integer(4) :: count
1994+
end subroutine
1995+
end interface
1996+
19901997
contains
19911998

19921999
attributes(device) subroutine syncthreads()

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -392,28 +392,17 @@ end subroutine
392392
! CHECK: %{{.*}} = nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
393393
! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
394394

395-
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
396-
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
397-
! CHECK-DAG: func.func private @__ldcs_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
398-
! CHECK-DAG: func.func private @__ldlu_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
399-
! CHECK-DAG: func.func private @__ldcv_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
400-
! CHECK-DAG: func.func private @__ldca_i8x2_(!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<2xi64>>)
401-
! CHECK-DAG: func.func private @__ldcg_i8x2_(!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<2xi64>>)
402-
! CHECK-DAG: func.func private @__ldcs_i8x2_(!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<2xi64>>)
403-
! CHECK-DAG: func.func private @__ldlu_i8x2_(!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<2xi64>>)
404-
! CHECK-DAG: func.func private @__ldcv_i8x2_(!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<2xi64>>)
405-
! CHECK-DAG: func.func private @__ldca_r4x4_(!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<4xf32>>)
406-
! CHECK-DAG: func.func private @__ldcg_r4x4_(!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<4xf32>>)
407-
! CHECK-DAG: func.func private @__ldcs_r4x4_(!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<4xf32>>)
408-
! CHECK-DAG: func.func private @__ldlu_r4x4_(!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<4xf32>>)
409-
! CHECK-DAG: func.func private @__ldcv_r4x4_(!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<4xf32>>)
410-
! CHECK-DAG: func.func private @__ldca_r2x2_(!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<2xf16>>)
411-
! CHECK-DAG: func.func private @__ldcg_r2x2_(!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<2xf16>>)
412-
! CHECK-DAG: func.func private @__ldcs_r2x2_(!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<2xf16>>)
413-
! CHECK-DAG: func.func private @__ldlu_r2x2_(!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<2xf16>>)
414-
! CHECK-DAG: func.func private @__ldcv_r2x2_(!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<2xf16>>)
415-
! CHECK-DAG: func.func private @__ldca_r8x2_(!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<2xf64>>)
416-
! CHECK-DAG: func.func private @__ldcg_r8x2_(!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<2xf64>>)
417-
! CHECK-DAG: func.func private @__ldcs_r8x2_(!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<2xf64>>)
418-
! CHECK-DAG: func.func private @__ldlu_r8x2_(!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<2xf64>>)
419-
! CHECK-DAG: func.func private @__ldcv_r8x2_(!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<2xf64>>)
395+
attributes(global) subroutine test_barrier()
396+
integer(8), shared :: barrier
397+
call barrier_init(barrier, 256)
398+
end subroutine
399+
400+
401+
! CHECK-LABEL: func.func @_QPtest_barrier()
402+
403+
! CHECK: %[[SHARED:.*]] = cuf.shared_memory i64 {bindc_name = "barrier", uniq_name = "_QFtest_barrierEbarrier"} -> !fir.ref<i64>
404+
! CHECK: %[[DECL_SHARED:.*]]:2 = hlfir.declare %[[SHARED]] {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_barrierEbarrier"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
405+
! CHECK: %[[COUNT:.*]] = arith.constant 256 : i32
406+
! CHECK: %[[LLVM_PTR:.*]] = fir.convert %[[DECL_SHARED]]#0 : (!fir.ref<i64>) -> !llvm.ptr
407+
! CHECK: %[[SHARED_PTR:.*]] = llvm.addrspacecast %[[LLVM_PTR]] : !llvm.ptr to !llvm.ptr<3>
408+
! CHECK: nvvm.mbarrier.init.shared %[[SHARED_PTR]], %[[COUNT]] : !llvm.ptr<3>, i32

0 commit comments

Comments
 (0)