Skip to content

Commit c417046

Browse files
authored
[flang][cuda] Lower set/get default stream for arrays (llvm#181432)
1 parent 49fa2a4 commit c417046

File tree

8 files changed

+127
-10
lines changed

8 files changed

+127
-10
lines changed

flang-rt/lib/cuda/allocator.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,9 @@ cudaStream_t RTDECL(CUFGetAssociatedStream)(void *p) {
141141
return nullptr;
142142
}
143143

144-
int RTDECL(CUFSetAssociatedStream)(void *p, cudaStream_t stream, bool hasStat,
145-
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
146-
Terminator terminator{sourceFile, sourceLine};
144+
int RTDECL(CUFSetAssociatedStream)(void *p, cudaStream_t stream) {
147145
if (p == nullptr) {
148-
return ReturnError(terminator, StatBaseNull, errMsg, hasStat);
146+
return StatBaseNull;
149147
}
150148
int pos = findAllocation(p);
151149
if (pos >= 0) {

flang-rt/unittests/Runtime/CUDA/Allocatable.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ TEST(AllocatableAsyncTest, SetStreamTest) {
205205

206206
// REAL(4), DEVICE, ALLOCATABLE :: b(:) - unallocated, base_addr is null
207207
auto b{createAllocatable(TypeCategory::Real, 4)};
208-
int stat2 = RTDECL(CUFSetAssociatedStream)(
209-
b->raw().base_addr, stream, true, nullptr, __FILE__, __LINE__);
208+
int stat2 = RTDECL(CUFSetAssociatedStream)(b->raw().base_addr, stream);
210209
EXPECT_EQ(stat2, StatBaseNull);
211210
}

flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
5151
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
5252
mlir::Value genClusterBlockIndex(mlir::Type, llvm::ArrayRef<mlir::Value>);
5353
mlir::Value genClusterDimBlocks(mlir::Type, llvm::ArrayRef<mlir::Value>);
54+
fir::ExtendedValue
55+
genCUDASetDefaultStreamArray(mlir::Type,
56+
llvm::ArrayRef<fir::ExtendedValue>);
57+
fir::ExtendedValue
58+
genCUDAGetDefaultStreamArg(mlir::Type,
59+
llvm::ArrayRef<fir::ExtendedValue>);
5460
void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>);
5561
template <const char *fctName, int extent>
5662
fir::ExtendedValue genLDXXFunc(mlir::Type,

flang/include/flang/Runtime/CUDA/allocator.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ extern "C" {
2121

2222
void RTDECL(CUFRegisterAllocator)();
2323
cudaStream_t RTDECL(CUFGetAssociatedStream)(void *);
24-
int RTDECL(CUFSetAssociatedStream)(void *, cudaStream_t, bool hasStat = false,
25-
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
26-
int sourceLine = 0);
24+
int RTDECL(CUFSetAssociatedStream)(void *, cudaStream_t);
2725
}
2826

2927
void *CUFAllocPinned(std::size_t, std::int64_t *);

flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "flang/Optimizer/Builder/MutableBox.h"
2020
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
2121
#include "flang/Optimizer/HLFIR/HLFIROps.h"
22+
#include "flang/Runtime/entry-names.h"
2223
#include "mlir/Dialect/Index/IR/IndexOps.h"
2324
#include "mlir/Dialect/SCF/IR/SCF.h"
2425
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -382,6 +383,16 @@ static constexpr IntrinsicHandler cudaHandlers[]{
382383
&CI::genClusterDimBlocks),
383384
{},
384385
/*isElemental=*/false},
386+
{"cudagetstreamdefaultarg",
387+
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
388+
&CI::genCUDAGetDefaultStreamArg),
389+
{{{"devptr", asAddr}}},
390+
/*isElemental=*/false},
391+
{"cudasetstreamarray",
392+
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
393+
&CI::genCUDASetDefaultStreamArray),
394+
{{{"devptr", asAddr}, {"stream", asValue}}},
395+
/*isElemental=*/false},
385396
{"fence_proxy_async",
386397
static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
387398
&CI::genFenceProxyAsync),
@@ -1103,6 +1114,46 @@ CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
11031114
return res;
11041115
}
11051116

1117+
// CUDASETSTREAMARRAY
1118+
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStreamArray(
1119+
mlir::Type resTy, llvm::ArrayRef<fir::ExtendedValue> args) {
1120+
assert(args.size() == 2);
1121+
mlir::Value arg = fir::getBase(args[0]);
1122+
mlir::Value stream = fir::getBase(args[1]);
1123+
1124+
if (mlir::isa<fir::BaseBoxType>(arg.getType()))
1125+
arg = fir::BoxAddrOp::create(builder, loc, arg);
1126+
mlir::Type i64Ty = builder.getI64Type();
1127+
mlir::Type i32Ty = builder.getI32Type();
1128+
auto ctx = builder.getContext();
1129+
mlir::Type voidPtrTy =
1130+
fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
1131+
mlir::FunctionType ftype =
1132+
mlir::FunctionType::get(ctx, {voidPtrTy, i64Ty}, {i32Ty});
1133+
mlir::Value voidPtr = builder.createConvert(loc, voidPtrTy, arg);
1134+
auto funcOp =
1135+
builder.createFunction(loc, RTNAME_STRING(CUFSetAssociatedStream), ftype);
1136+
auto call = fir::CallOp::create(builder, loc, funcOp, {voidPtr, stream});
1137+
return call.getResult(0);
1138+
}
1139+
1140+
// CUDAGETDEFAULTSTREAMARG
1141+
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDAGetDefaultStreamArg(
1142+
mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) {
1143+
assert(args.size() == 1);
1144+
mlir::Value devptr = fir::getBase(args[0]);
1145+
mlir::Type i64Ty = builder.getI64Type();
1146+
auto ctx = builder.getContext();
1147+
mlir::Type voidPtrTy =
1148+
fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
1149+
mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {voidPtrTy}, {i64Ty});
1150+
mlir::Value voidPtr = builder.createConvert(loc, voidPtrTy, devptr);
1151+
auto funcOp =
1152+
builder.createFunction(loc, RTNAME_STRING(CUFGetAssociatedStream), ftype);
1153+
auto call = fir::CallOp::create(builder, loc, funcOp, {voidPtr});
1154+
return call.getResult(0);
1155+
}
1156+
11061157
// FENCE_PROXY_ASYNC
11071158
void CUDAIntrinsicLibrary::genFenceProxyAsync(
11081159
llvm::ArrayRef<fir::ExtendedValue> args) {

flang/module/cuda_runtime_api.f90

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
!===-- module/cuda_runtime_api.f90 -----------------------------------------===!
2+
!
3+
! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
! See https://llvm.org/LICENSE.txt for license information.
5+
! SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
!
7+
!===------------------------------------------------------------------------===!
8+
9+
module cuda_runtime_api
  implicit none

  ! Integer kind used for stream values throughout the interfaces below;
  ! it is whatever kind int_ptr_kind() reports.
  integer, parameter :: cuda_stream_kind = int_ptr_kind()

  ! Generic stream query: either for a given device array or with no argument.
  interface cudaforgetdefaultstream
    integer(kind=cuda_stream_kind) function cudagetstreamdefaultarg(devptr)
      import :: cuda_stream_kind
      ! Type, kind and rank of the actual argument are not checked.
      !DIR$ IGNORE_TKR (TKR) devptr
      integer, device :: devptr(*)
    end function cudagetstreamdefaultarg
    integer(kind=cuda_stream_kind) function cudastreamgetdefaultnull()
      import :: cuda_stream_kind
    end function cudastreamgetdefaultnull
  end interface cudaforgetdefaultstream

  ! Generic stream setter: either a stream alone or a device array plus a
  ! stream. Both specifics return a default-kind integer.
  interface cudaforsetdefaultstream
    integer function cudasetdefaultstream(stream)
      import :: cuda_stream_kind
      ! The integer kind of the actual stream argument is not checked.
      !DIR$ IGNORE_TKR (K) stream
      integer(kind=cuda_stream_kind), value :: stream
    end function cudasetdefaultstream
    integer function cudasetstreamarray(devptr, stream)
      import :: cuda_stream_kind
      !DIR$ IGNORE_TKR (K) stream, (TKR) devptr
      integer, device :: devptr(*)
      integer(kind=cuda_stream_kind), value :: stream
    end function cudasetstreamarray
  end interface cudaforsetdefaultstream

end module cuda_runtime_api
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
! Lowering test input: calls the set and get generics from cuda_runtime_api
! with a managed allocatable array so that the array-taking specifics are
! selected. The expected IR for both calls is listed after the subroutine.
subroutine associated_stream
4+
use cuda_runtime_api
5+
integer(kind=cuda_stream_kind) :: strm, strmout
6+
integer, managed, allocatable :: v(:)
7+
integer :: istat
8+
9+
! Array + stream form; the expected output shows a call to
! _FortranACUFSetAssociatedStream whose i32 result is assigned to istat.
istat = cudaforSetDefaultStream(v, strm)
10+
! Array-only form; the expected output shows a call to
! _FortranACUFGetAssociatedStream whose i64 result is assigned to strmout.
strmout = cudaforGetDefaultStream(v)
11+
12+
end subroutine
13+
14+
! CHECK-LABEL: func.func @_QPassociated_stream()
15+
! CHECK: %[[ADDR:.*]] = fir.box_addr %{{.*}} : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
16+
! CHECK: %[[STREAM:.*]] = fir.load %{{.*}}#0 : !fir.ref<i64>
17+
! CHECK: %[[VOIDPTR:.*]] = fir.convert %[[ADDR]] : (!fir.heap<!fir.array<?xi32>>) -> !fir.llvm_ptr<i8>
18+
! CHECK: %[[STAT:.*]] = fir.call @_FortranACUFSetAssociatedStream(%[[VOIDPTR]], %[[STREAM]]) fastmath<contract> : (!fir.llvm_ptr<i8>, i64) -> i32
19+
! CHECK: hlfir.assign %[[STAT]] to %{{.*}}#0 : i32, !fir.ref<i32>
20+
21+
! CHECK: %[[ADDR:.*]] = fir.box_addr %{{.*}} : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
22+
! CHECK: %[[VOIDPTR:.*]] = fir.convert %[[ADDR]] : (!fir.heap<!fir.array<?xi32>>) -> !fir.llvm_ptr<i8>
23+
! CHECK: %[[STREAM:.*]] = fir.call @_FortranACUFGetAssociatedStream(%[[VOIDPTR]]) fastmath<contract> : (!fir.llvm_ptr<i8>) -> i64
24+
! CHECK: hlfir.assign %[[STREAM]] to %{{.*}}#0 : i64, !fir.ref<i64>

flang/tools/f18/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ set(MODULES
1616
"__cuda_builtins"
1717
"__cuda_device"
1818
"cooperative_groups"
19+
"cuda_runtime_api"
1920
"cudadevice"
2021
"ieee_arithmetic"
2122
"ieee_exceptions"
@@ -64,7 +65,8 @@ if (NOT CMAKE_CROSSCOMPILING)
6465
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__ppc_types.mod)
6566
elseif(${filename} STREQUAL "__cuda_device" OR
6667
${filename} STREQUAL "cudadevice" OR
67-
${filename} STREQUAL "cooperative_groups")
68+
${filename} STREQUAL "cooperative_groups" OR
69+
${filename} STREQUAL "cuda_runtime_api")
6870
set(opts -fc1 -xcuda)
6971
if(${filename} STREQUAL "__cuda_device")
7072
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)

0 commit comments

Comments
 (0)