Skip to content

Commit 910d7e9

Browse files
authored
[Offload] Make olLaunchKernel test thread safe (#149497)
This sprinkles a few mutexes around the plugin interface so that the olLaunchKernel CTS test now passes when ran on multiple threads. Part of this also involved changing the interface for device synchronise so that it can optionally not free the underlying queue (which introduced a race condition in liboffload).
1 parent 24ea155 commit 910d7e9

File tree

10 files changed

+120
-54
lines changed

10 files changed

+120
-54
lines changed

offload/include/Shared/APITypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <cstddef>
2323
#include <cstdint>
24+
#include <mutex>
2425

2526
extern "C" {
2627

@@ -76,6 +77,9 @@ struct __tgt_async_info {
7677
/// should be freed after finalization.
7778
llvm::SmallVector<void *, 2> AssociatedAllocations;
7879

80+
/// Mutex to guard access to AssociatedAllocations and the Queue.
81+
std::mutex Mutex;
82+
7983
/// The kernel launch environment used to issue a kernel. Stored here to
8084
/// ensure it is a valid location while the transfer to the device is
8185
/// happening.

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ Error initPlugins(OffloadContext &Context) {
208208
}
209209

210210
Error olInit_impl() {
211-
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
211+
std::lock_guard<std::mutex> Lock(OffloadContextValMutex);
212212

213213
if (isOffloadInitialized()) {
214214
OffloadContext::get().RefCount++;
@@ -226,7 +226,7 @@ Error olInit_impl() {
226226
}
227227

228228
Error olShutDown_impl() {
229-
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
229+
std::lock_guard<std::mutex> Lock(OffloadContextValMutex);
230230

231231
if (--OffloadContext::get().RefCount != 0)
232232
return Error::success();
@@ -487,16 +487,13 @@ Error olSyncQueue_impl(ol_queue_handle_t Queue) {
487487
// Host plugin doesn't have a queue set so it's not safe to call synchronize
488488
// on it, but we have nothing to synchronize in that situation anyway.
489489
if (Queue->AsyncInfo->Queue) {
490-
if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
490+
// We don't need to release the queue and we would like the ability for
491+
// other offload threads to submit work concurrently, so pass "false" here
492+
// so we don't release the underlying queue object.
493+
if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo, false))
491494
return Err;
492495
}
493496

494-
// Recreate the stream resource so the queue can be reused
495-
// TODO: Would be easier for the synchronization to (optionally) not release
496-
// it to begin with.
497-
if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
498-
return Res;
499-
500497
return Error::success();
501498
}
502499

@@ -747,7 +744,7 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
747744
ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
748745
auto &Device = Program->Image->getDevice();
749746

750-
std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
747+
std::lock_guard<std::mutex> Lock(Program->SymbolListMutex);
751748

752749
switch (Kind) {
753750
case OL_SYMBOL_KIND_KERNEL: {

offload/libomptarget/interface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
116116
TargetDataFuncPtrTy TargetDataFunction, const char *RegionTypeMsg,
117117
const char *RegionName) {
118118
assert(PM && "Runtime not initialized");
119-
static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
119+
static_assert(std::is_convertible_v<TargetAsyncInfoTy &, AsyncInfoTy &>,
120120
"TargetAsyncInfoTy must be convertible to AsyncInfoTy.");
121121

122122
TIMESCOPE_WITH_DETAILS_AND_IDENT("Runtime: Data Copy",
@@ -311,7 +311,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
311311
int32_t ThreadLimit, void *HostPtr,
312312
KernelArgsTy *KernelArgs) {
313313
assert(PM && "Runtime not initialized");
314-
static_assert(std::is_convertible_v<TargetAsyncInfoTy, AsyncInfoTy>,
314+
static_assert(std::is_convertible_v<TargetAsyncInfoTy &, AsyncInfoTy &>,
315315
"Target AsyncInfoTy must be convertible to AsyncInfoTy.");
316316
DP("Entering target region for device %" PRId64 " with entry point " DPxMOD
317317
"\n",

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,16 +2232,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
22322232
/// Get the stream of the asynchronous info structure or get a new one.
22332233
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper,
22342234
AMDGPUStreamTy *&Stream) {
2235-
// Get the stream (if any) from the async info.
2236-
Stream = AsyncInfoWrapper.getQueueAs<AMDGPUStreamTy *>();
2237-
if (!Stream) {
2238-
// There was no stream; get an idle one.
2239-
if (auto Err = AMDGPUStreamManager.getResource(Stream))
2240-
return Err;
2241-
2242-
// Modify the async info's stream.
2243-
AsyncInfoWrapper.setQueueAs<AMDGPUStreamTy *>(Stream);
2244-
}
2235+
auto WrapperStream =
2236+
AsyncInfoWrapper.getOrInitQueue<AMDGPUStreamTy *>(AMDGPUStreamManager);
2237+
if (!WrapperStream)
2238+
return WrapperStream.takeError();
2239+
Stream = *WrapperStream;
22452240
return Plugin::success();
22462241
}
22472242

@@ -2296,7 +2291,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
22962291
}
22972292

22982293
/// Synchronize current thread with the pending operations on the async info.
2299-
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
2294+
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
2295+
bool ReleaseQueue) override {
23002296
AMDGPUStreamTy *Stream =
23012297
reinterpret_cast<AMDGPUStreamTy *>(AsyncInfo.Queue);
23022298
assert(Stream && "Invalid stream");
@@ -2307,8 +2303,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
23072303
// Once the stream is synchronized, return it to stream pool and reset
23082304
// AsyncInfo. This is to make sure the synchronization only works for its
23092305
// own tasks.
2310-
AsyncInfo.Queue = nullptr;
2311-
return AMDGPUStreamManager.returnResource(Stream);
2306+
if (ReleaseQueue) {
2307+
AsyncInfo.Queue = nullptr;
2308+
return AMDGPUStreamManager.returnResource(Stream);
2309+
}
2310+
return Plugin::success();
23122311
}
23132312

23142313
/// Query for the completion of the pending operations on the async info.

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ struct GenericPluginTy;
6060
struct GenericKernelTy;
6161
struct GenericDeviceTy;
6262
struct RecordReplayTy;
63+
template <typename ResourceRef> class GenericDeviceResourceManagerTy;
6364

6465
namespace Plugin {
6566
/// Create a success error. This is the same as calling Error::success(), but
@@ -127,6 +128,20 @@ struct AsyncInfoWrapperTy {
127128
AsyncInfoPtr->Queue = Queue;
128129
}
129130

131+
/// Get the queue, using the provided resource manager to initialise it if it
132+
/// doesn't exist.
133+
template <typename Ty, typename RMTy>
134+
Expected<Ty>
135+
getOrInitQueue(GenericDeviceResourceManagerTy<RMTy> &ResourceManager) {
136+
std::lock_guard<std::mutex> Lock(AsyncInfoPtr->Mutex);
137+
if (!AsyncInfoPtr->Queue) {
138+
if (auto Err = ResourceManager.getResource(
139+
*reinterpret_cast<Ty *>(&AsyncInfoPtr->Queue)))
140+
return Err;
141+
}
142+
return getQueueAs<Ty>();
143+
}
144+
130145
/// Synchronize with the __tgt_async_info's pending operations if it's the
131146
/// internal async info. The error associated to the asynchronous operations
132147
/// issued in this queue must be provided in \p Err. This function will update
@@ -138,6 +153,7 @@ struct AsyncInfoWrapperTy {
138153
/// Register \p Ptr as an associated allocation that is freed after
139154
/// finalization.
140155
void freeAllocationAfterSynchronization(void *Ptr) {
156+
std::lock_guard<std::mutex> AllocationGuard(AsyncInfoPtr->Mutex);
141157
AsyncInfoPtr->AssociatedAllocations.push_back(Ptr);
142158
}
143159

@@ -827,9 +843,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
827843
Error setupRPCServer(GenericPluginTy &Plugin, DeviceImageTy &Image);
828844

829845
/// Synchronize the current thread with the pending operations on the
830-
/// __tgt_async_info structure.
831-
Error synchronize(__tgt_async_info *AsyncInfo);
832-
virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo) = 0;
846+
/// __tgt_async_info structure. If ReleaseQueue is false, then the
847+
// underlying queue will not be released. In this case, additional
848+
// work may be submitted to the queue whilst a synchronize is running.
849+
Error synchronize(__tgt_async_info *AsyncInfo, bool ReleaseQueue = true);
850+
virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo,
851+
bool ReleaseQueue) = 0;
833852

834853
/// Invokes any global constructors on the device if present and is required
835854
/// by the target.

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,18 +1335,25 @@ Error PinnedAllocationMapTy::unlockUnmappedHostBuffer(void *HstPtr) {
13351335
return eraseEntry(*Entry);
13361336
}
13371337

1338-
Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo) {
1339-
if (!AsyncInfo || !AsyncInfo->Queue)
1340-
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
1341-
"invalid async info queue");
1338+
Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo,
1339+
bool ReleaseQueue) {
1340+
SmallVector<void *> AllocsToDelete{};
1341+
{
1342+
std::lock_guard<std::mutex> AllocationGuard{AsyncInfo->Mutex};
13421343

1343-
if (auto Err = synchronizeImpl(*AsyncInfo))
1344-
return Err;
1344+
if (!AsyncInfo || !AsyncInfo->Queue)
1345+
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
1346+
"invalid async info queue");
1347+
1348+
if (auto Err = synchronizeImpl(*AsyncInfo, ReleaseQueue))
1349+
return Err;
1350+
1351+
std::swap(AllocsToDelete, AsyncInfo->AssociatedAllocations);
1352+
}
13451353

1346-
for (auto *Ptr : AsyncInfo->AssociatedAllocations)
1354+
for (auto *Ptr : AllocsToDelete)
13471355
if (auto Err = dataDelete(Ptr, TargetAllocTy::TARGET_ALLOC_DEVICE))
13481356
return Err;
1349-
AsyncInfo->AssociatedAllocations.clear();
13501357

13511358
return Plugin::success();
13521359
}

offload/plugins-nextgen/cuda/src/rtl.cpp

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -522,16 +522,11 @@ struct CUDADeviceTy : public GenericDeviceTy {
522522

523523
/// Get the stream of the asynchronous info structure or get a new one.
524524
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
525-
// Get the stream (if any) from the async info.
526-
Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
527-
if (!Stream) {
528-
// There was no stream; get an idle one.
529-
if (auto Err = CUDAStreamManager.getResource(Stream))
530-
return Err;
531-
532-
// Modify the async info's stream.
533-
AsyncInfoWrapper.setQueueAs<CUstream>(Stream);
534-
}
525+
auto WrapperStream =
526+
AsyncInfoWrapper.getOrInitQueue<CUstream>(CUDAStreamManager);
527+
if (!WrapperStream)
528+
return WrapperStream.takeError();
529+
Stream = *WrapperStream;
535530
return Plugin::success();
536531
}
537532

@@ -642,17 +637,20 @@ struct CUDADeviceTy : public GenericDeviceTy {
642637
}
643638

644639
/// Synchronize current thread with the pending operations on the async info.
645-
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
640+
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
641+
bool ReleaseQueue) override {
646642
CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
647643
CUresult Res;
648644
Res = cuStreamSynchronize(Stream);
649645

650-
// Once the stream is synchronized, return it to stream pool and reset
651-
// AsyncInfo. This is to make sure the synchronization only works for its
652-
// own tasks.
653-
AsyncInfo.Queue = nullptr;
654-
if (auto Err = CUDAStreamManager.returnResource(Stream))
655-
return Err;
646+
// Once the stream is synchronized and we want to release the queue, return
647+
// it to stream pool and reset AsyncInfo. This is to make sure the
648+
// synchronization only works for its own tasks.
649+
if (ReleaseQueue) {
650+
AsyncInfo.Queue = nullptr;
651+
if (auto Err = CUDAStreamManager.returnResource(Stream))
652+
return Err;
653+
}
656654

657655
return Plugin::check(Res, "error in cuStreamSynchronize: %s");
658656
}

offload/plugins-nextgen/host/src/rtl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
297297

298298
/// All functions are already synchronous. No need to do anything on this
299299
/// synchronization function.
300-
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
300+
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
301+
bool ReleaseQueue) override {
301302
return Plugin::success();
302303
}
303304

offload/unittests/OffloadAPI/common/Fixtures.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <OffloadAPI.h>
1010
#include <OffloadPrint.hpp>
1111
#include <gtest/gtest.h>
12+
#include <thread>
1213

1314
#include "Environment.hpp"
1415

@@ -57,6 +58,23 @@ inline std::string SanitizeString(const std::string &Str) {
5758
return NewStr;
5859
}
5960

61+
template <typename Fn> inline void threadify(Fn body) {
62+
std::vector<std::thread> Threads;
63+
for (size_t I = 0; I < 20; I++) {
64+
Threads.emplace_back(
65+
[&body](size_t I) {
66+
std::string ScopeMsg{"Thread #"};
67+
ScopeMsg.append(std::to_string(I));
68+
SCOPED_TRACE(ScopeMsg);
69+
body(I);
70+
},
71+
I);
72+
}
73+
for (auto &T : Threads) {
74+
T.join();
75+
}
76+
}
77+
6078
struct OffloadTest : ::testing::Test {
6179
ol_device_handle_t Host = TestEnvironment::getHostDevice();
6280
};

offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,29 @@ TEST_P(olLaunchKernelFooTest, Success) {
104104
ASSERT_SUCCESS(olMemFree(Mem));
105105
}
106106

107+
TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
108+
threadify([&](size_t) {
109+
void *Mem;
110+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
111+
LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));
112+
struct {
113+
void *Mem;
114+
} Args{Mem};
115+
116+
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args),
117+
&LaunchArgs));
118+
119+
ASSERT_SUCCESS(olSyncQueue(Queue));
120+
121+
uint32_t *Data = (uint32_t *)Mem;
122+
for (uint32_t i = 0; i < 64; i++) {
123+
ASSERT_EQ(Data[i], i);
124+
}
125+
126+
ASSERT_SUCCESS(olMemFree(Mem));
127+
});
128+
}
129+
107130
TEST_P(olLaunchKernelNoArgsTest, Success) {
108131
ASSERT_SUCCESS(
109132
olLaunchKernel(Queue, Device, Kernel, nullptr, 0, &LaunchArgs));

0 commit comments

Comments
 (0)