Skip to content

Commit 8cd0873

Browse files
committed
[Offload] Make olLaunchKernel test thread safe
This sprinkles a few mutexes around the plugin interface so that the olLaunchKernel CTS test now passes when ran on multiple threads. Part of this also involved changing the interface for device synchronise so that it can optionally not free the underlying queue (which introduced a race condition in liboffload).
1 parent 20051b7 commit 8cd0873

File tree

9 files changed

+81
-19
lines changed

9 files changed

+81
-19
lines changed

offload/include/Shared/APITypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <cstddef>
2323
#include <cstdint>
24+
#include <mutex>
2425

2526
extern "C" {
2627

@@ -76,6 +77,9 @@ struct __tgt_async_info {
7677
/// should be freed after finalization.
7778
llvm::SmallVector<void *, 2> AssociatedAllocations;
7879

80+
/// Mutex to guard access to AssociatedAllocations
81+
std::mutex AllocationsMutex;
82+
7983
/// The kernel launch environment used to issue a kernel. Stored here to
8084
/// ensure it is a valid location while the transfer to the device is
8185
/// happening.

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -487,16 +487,10 @@ Error olSyncQueue_impl(ol_queue_handle_t Queue) {
487487
// Host plugin doesn't have a queue set so it's not safe to call synchronize
488488
// on it, but we have nothing to synchronize in that situation anyway.
489489
if (Queue->AsyncInfo->Queue) {
490-
if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
490+
if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo, false))
491491
return Err;
492492
}
493493

494-
// Recreate the stream resource so the queue can be reused
495-
// TODO: Would be easier for the synchronization to (optionally) not release
496-
// it to begin with.
497-
if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
498-
return Res;
499-
500494
return Error::success();
501495
}
502496

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,6 +2232,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
22322232
/// Get the stream of the asynchronous info structure or get a new one.
22332233
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper,
22342234
AMDGPUStreamTy *&Stream) {
2235+
std::lock_guard<std::mutex> StreamLock{StreamMutex};
22352236
// Get the stream (if any) from the async info.
22362237
Stream = AsyncInfoWrapper.getQueueAs<AMDGPUStreamTy *>();
22372238
if (!Stream) {
@@ -2296,7 +2297,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
22962297
}
22972298

22982299
/// Synchronize current thread with the pending operations on the async info.
2299-
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
2300+
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
2301+
bool RemoveQueue) override {
23002302
AMDGPUStreamTy *Stream =
23012303
reinterpret_cast<AMDGPUStreamTy *>(AsyncInfo.Queue);
23022304
assert(Stream && "Invalid stream");
@@ -2307,8 +2309,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
23072309
// Once the stream is synchronized, return it to stream pool and reset
23082310
// AsyncInfo. This is to make sure the synchronization only works for its
23092311
// own tasks.
2310-
AsyncInfo.Queue = nullptr;
2311-
return AMDGPUStreamManager.returnResource(Stream);
2312+
if (RemoveQueue) {
2313+
AsyncInfo.Queue = nullptr;
2314+
return AMDGPUStreamManager.returnResource(Stream);
2315+
}
2316+
return Plugin::success();
23122317
}
23132318

23142319
/// Query for the completion of the pending operations on the async info.
@@ -3067,6 +3072,9 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
30673072
/// True is the system is configured with XNACK-Enabled.
30683073
/// False otherwise.
30693074
bool IsXnackEnabled = false;
3075+
3076+
/// Mutex to guard getting/setting the stream
3077+
std::mutex StreamMutex;
30703078
};
30713079

30723080
Error AMDGPUDeviceImageTy::loadExecutable(const AMDGPUDeviceTy &Device) {

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ struct AsyncInfoWrapperTy {
138138
/// Register \p Ptr as an associated allocation that is freed after
139139
/// finalization.
140140
void freeAllocationAfterSynchronization(void *Ptr) {
141+
std::lock_guard<std::mutex> AllocationGuard{AsyncInfoPtr->AllocationsMutex};
141142
AsyncInfoPtr->AssociatedAllocations.push_back(Ptr);
142143
}
143144

@@ -828,8 +829,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
828829

829830
/// Synchronize the current thread with the pending operations on the
830831
/// __tgt_async_info structure.
831-
Error synchronize(__tgt_async_info *AsyncInfo);
832-
virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo) = 0;
832+
Error synchronize(__tgt_async_info *AsyncInfo, bool RemoveQueue = true);
833+
virtual Error synchronizeImpl(__tgt_async_info &AsyncInfo,
834+
bool RemoveQueue) = 0;
833835

834836
/// Invokes any global constructors on the device if present and is required
835837
/// by the target.
@@ -1591,6 +1593,8 @@ template <typename ResourceRef> class GenericDeviceResourceManagerTy {
15911593
/// Deinitialize the resource pool and delete all resources. This function
15921594
/// must be called before the destructor.
15931595
virtual Error deinit() {
1596+
const std::lock_guard<std::mutex> Lock(Mutex);
1597+
15941598
if (NextAvailable)
15951599
DP("Missing %d resources to be returned\n", NextAvailable);
15961600

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,12 +1335,15 @@ Error PinnedAllocationMapTy::unlockUnmappedHostBuffer(void *HstPtr) {
13351335
return eraseEntry(*Entry);
13361336
}
13371337

1338-
Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo) {
1338+
Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo,
1339+
bool RemoveQueue) {
1340+
std::lock_guard<std::mutex> AllocationGuard{AsyncInfo->AllocationsMutex};
1341+
13391342
if (!AsyncInfo || !AsyncInfo->Queue)
13401343
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
13411344
"invalid async info queue");
13421345

1343-
if (auto Err = synchronizeImpl(*AsyncInfo))
1346+
if (auto Err = synchronizeImpl(*AsyncInfo, RemoveQueue))
13441347
return Err;
13451348

13461349
for (auto *Ptr : AsyncInfo->AssociatedAllocations)

offload/plugins-nextgen/cuda/src/rtl.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
522522

523523
/// Get the stream of the asynchronous info structure or get a new one.
524524
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
525+
std::lock_guard<std::mutex> StreamLock{StreamMutex};
525526
// Get the stream (if any) from the async info.
526527
Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
527528
if (!Stream) {
@@ -642,17 +643,20 @@ struct CUDADeviceTy : public GenericDeviceTy {
642643
}
643644

644645
/// Synchronize current thread with the pending operations on the async info.
645-
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
646+
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
647+
bool RemoveQueue) override {
646648
CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
647649
CUresult Res;
648650
Res = cuStreamSynchronize(Stream);
649651

650652
// Once the stream is synchronized, return it to stream pool and reset
651653
// AsyncInfo. This is to make sure the synchronization only works for its
652654
// own tasks.
653-
AsyncInfo.Queue = nullptr;
654-
if (auto Err = CUDAStreamManager.returnResource(Stream))
655-
return Err;
655+
if (RemoveQueue) {
656+
AsyncInfo.Queue = nullptr;
657+
if (auto Err = CUDAStreamManager.returnResource(Stream))
658+
return Err;
659+
}
656660

657661
return Plugin::check(Res, "error in cuStreamSynchronize: %s");
658662
}
@@ -1289,6 +1293,9 @@ struct CUDADeviceTy : public GenericDeviceTy {
12891293
/// The maximum number of warps that can be resident on all the SMs
12901294
/// simultaneously.
12911295
uint32_t HardwareParallelism = 0;
1296+
1297+
/// Mutex to guard getting/setting the stream
1298+
std::mutex StreamMutex;
12921299
};
12931300

12941301
Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,

offload/plugins-nextgen/host/src/rtl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
297297

298298
/// All functions are already synchronous. No need to do anything on this
299299
/// synchronization function.
300-
Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
300+
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
301+
bool RemoveQueue) override {
301302
return Plugin::success();
302303
}
303304

offload/unittests/OffloadAPI/common/Fixtures.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <OffloadAPI.h>
1010
#include <OffloadPrint.hpp>
1111
#include <gtest/gtest.h>
12+
#include <thread>
1213

1314
#include "Environment.hpp"
1415

@@ -57,6 +58,23 @@ inline std::string SanitizeString(const std::string &Str) {
5758
return NewStr;
5859
}
5960

61+
template <typename Fn> inline void threadify(Fn body) {
62+
std::vector<std::thread> Threads;
63+
for (size_t I = 0; I < 20; I++) {
64+
Threads.emplace_back(
65+
[&body](size_t I) {
66+
std::string ScopeMsg{"Thread #"};
67+
ScopeMsg.append(std::to_string(I));
68+
SCOPED_TRACE(ScopeMsg);
69+
body(I);
70+
},
71+
I);
72+
}
73+
for (auto &T : Threads) {
74+
T.join();
75+
}
76+
}
77+
6078
struct OffloadTest : ::testing::Test {
6179
ol_device_handle_t Host = TestEnvironment::getHostDevice();
6280
};

offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,29 @@ TEST_P(olLaunchKernelFooTest, Success) {
104104
ASSERT_SUCCESS(olMemFree(Mem));
105105
}
106106

107+
TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
108+
threadify([&](size_t) {
109+
void *Mem;
110+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
111+
LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));
112+
struct {
113+
void *Mem;
114+
} Args{Mem};
115+
116+
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args),
117+
&LaunchArgs, nullptr));
118+
119+
ASSERT_SUCCESS(olWaitQueue(Queue));
120+
121+
uint32_t *Data = (uint32_t *)Mem;
122+
for (uint32_t i = 0; i < 64; i++) {
123+
ASSERT_EQ(Data[i], i);
124+
}
125+
126+
ASSERT_SUCCESS(olMemFree(Mem));
127+
});
128+
}
129+
107130
TEST_P(olLaunchKernelNoArgsTest, Success) {
108131
ASSERT_SUCCESS(
109132
olLaunchKernel(Queue, Device, Kernel, nullptr, 0, &LaunchArgs));

0 commit comments

Comments
 (0)