Skip to content

Commit 2f3b912

Browse files
[Backend] Enable default device sharing across handles for multi threaded handle+device creation (#114)
### Issue The `Multi-threaded Handle creation` test was failing flakily on my recent local runs: ``` /home/srajeshk/fusilli/tests/test_handle.cpp:83: FAILED: REQUIRE( !creationFailed.load() ) with expansion: false ``` The test spawns 32 threads that simultaneously attempt to create CPU handles, and at least one handle creation was failing. ### Root Cause IREE's HAL drivers (both `local-task` for CPU and HIP for GPU) provide a **single default device** per configuration. When multiple threads concurrently called `iree_runtime_instance_try_create_default_device()`, each attempting to create its own device: 1. The first thread succeeded in creating the device 2. Subsequent threads failed because the default device already existed IREE's default device model seems to support one "default" device per driver, not one per handle. ### Fix Implement **device sharing** across handles using a `weak_ptr` caching pattern (similar to the existing `createSharedInstance()` pattern for the runtime instance): **For CPU (`createCPUDevice`):** - Single shared device across all CPU handles - Uses `static std::weak_ptr<iree_hal_device_t>` with mutex protection - First handle creates the device; subsequent handles reuse it **For AMDGPU (`createAMDGPUDevice`):** - Devices cached by `(deviceId, stream)` configuration - Uses `static std::map<key, std::weak_ptr<iree_hal_device_t>>` with mutex protection - Handles with the same configuration share the device; different configurations get separate devices ### Changes | File | Change | |------|--------| | `backend.h` | Added `IreeHalDeviceSharedPtrType`, removed unused `IreeHalDeviceUniquePtrType` | | `handle.h` | Changed `device_` from `unique_ptr` to `shared_ptr` | | `runtime.h` | Implemented device sharing in `createCPUDevice()` and `createAMDGPUDevice()` | | `test_handle.cpp` | Added multi-threaded test for AMDGPU backend | ### Test Plan - [x] `Multi-threaded Handle creation` test passes (CPU backend) - [x] `Multi-threaded Handle creation AMDGPU` test passes (GPU backend) ### Disclaimer - PR description and code fixes generated with assistance from Claude 4.5 Opus under my supervision. --------- Signed-off-by: Sambhav Jain <[email protected]>
1 parent dc4810f commit 2f3b912

File tree

4 files changed

+117
-16
lines changed

4 files changed

+117
-16
lines changed

include/fusilli/backend/backend.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,7 @@ struct IreeHalBufferViewDeleter {
356356
// Aliases for IREE runtime types with custom deleters.
357357
using IreeRuntimeInstanceSharedPtrType =
358358
std::shared_ptr<iree_runtime_instance_t>;
359-
using IreeHalDeviceUniquePtrType =
360-
std::unique_ptr<iree_hal_device_t, IreeHalDeviceDeleter>;
359+
using IreeHalDeviceSharedPtrType = std::shared_ptr<iree_hal_device_t>;
361360
using IreeRuntimeSessionUniquePtrType =
362361
std::unique_ptr<iree_runtime_session_t, IreeRuntimeSessionDeleter>;
363362
using IreeHalBufferViewUniquePtrType =

include/fusilli/backend/handle.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@ class Handle {
160160
// `device_` depends on `backend_` and `instance_`.
161161
Backend backend_;
162162
IreeRuntimeInstanceSharedPtrType instance_;
163-
IreeHalDeviceUniquePtrType device_;
163+
// Use shared_ptr to allow sharing devices across handles. For CPU, there's
164+
// only one logical device shared by all handles. For GPU, devices with the
165+
// same (deviceId, stream) configuration are shared.
166+
IreeHalDeviceSharedPtrType device_;
164167
};
165168

166169
} // namespace fusilli

include/fusilli/backend/runtime.h

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include <iree/runtime/api.h>
4545

4646
#include <cstdint>
47+
#include <map>
4748
#include <memory>
4849
#include <mutex>
4950
#include <string>
@@ -61,7 +62,7 @@ namespace fusilli {
6162
// Create static singleton IREE runtime instance shared across handles/threads.
6263
inline ErrorOr<IreeRuntimeInstanceSharedPtrType>
6364
Handle::createSharedInstance() {
64-
// Mutex for thread-safe initialization of weakInstance.
65+
// Mutex for thread-safe initialization of the shared instance.
6566
static std::mutex instanceMutex;
6667

6768
// Static weak_ptr to the IREE runtime instance ensures that the
@@ -72,9 +73,7 @@ Handle::createSharedInstance() {
7273
// static variable goes out of scope upon program termination.
7374
static std::weak_ptr<iree_runtime_instance_t> weakInstance;
7475

75-
// If multiple threads simultaneously request a handle, they will
76-
// race into `createSharedInstance()` but only one will succeed in
77-
// creating the instance, and others will use it.
76+
// Serialize access to the weak_ptr check-then-create logic.
7877
std::lock_guard<std::mutex> lock(instanceMutex);
7978

8079
// Try to get the shared_ptr from the weak_ptr (if it exists).
@@ -106,15 +105,38 @@ Handle::createSharedInstance() {
106105
inline ErrorObject Handle::createCPUDevice() {
107106
FUSILLI_LOG_LABEL_ENDL("INFO: Creating per-handle IREE HAL device");
108107

109-
iree_hal_device_t *rawDevice = nullptr;
110-
FUSILLI_CHECK_ERROR(iree_runtime_instance_try_create_default_device(
111-
instance_.get(), iree_make_cstring_view(kHalDriver.at(backend_)),
112-
&rawDevice));
108+
// Mutex for thread-safe access to the shared CPU device.
109+
static std::mutex cpuDeviceMutex;
113110

114-
// Wrap the raw device ptr with a unique_ptr and custom deleter
115-
// for lifetime management.
116-
device_ = IreeHalDeviceUniquePtrType(rawDevice);
111+
// Static weak_ptr to the CPU device ensures that the device is only
112+
// created once and shared across all CPU handles. This is necessary
113+
// because IREE's local-task driver typically provides a single default
114+
// device. The device is released when the last handle using it goes
115+
// out of scope.
116+
static std::weak_ptr<iree_hal_device_t> weakCpuDevice;
117+
118+
// Serialize access to the weak_ptr check-then-create logic.
119+
std::lock_guard<std::mutex> lock(cpuDeviceMutex);
120+
121+
// Try to get the shared_ptr from the weak_ptr (if it exists).
122+
IreeHalDeviceSharedPtrType sharedDevice = weakCpuDevice.lock();
123+
124+
// If weak_ptr expired, create a new CPU device.
125+
if (sharedDevice == nullptr) {
126+
FUSILLI_LOG_LABEL_ENDL("INFO: Creating shared CPU device");
127+
iree_hal_device_t *rawDevice = nullptr;
128+
FUSILLI_CHECK_ERROR(iree_runtime_instance_try_create_default_device(
129+
instance_.get(), iree_make_cstring_view(kHalDriver.at(backend_)),
130+
&rawDevice));
131+
132+
// Wrap the raw device ptr with a shared_ptr and custom deleter
133+
// for lifetime management.
134+
sharedDevice =
135+
IreeHalDeviceSharedPtrType(rawDevice, IreeHalDeviceDeleter());
136+
weakCpuDevice = sharedDevice;
137+
}
117138

139+
device_ = sharedDevice;
118140
return ok();
119141
}
120142

@@ -130,6 +152,36 @@ inline ErrorObject Handle::createAMDGPUDevice(int deviceId, uintptr_t stream) {
130152
// when building Fusilli without AMDGPU support (which disables IREE HAL
131153
// HIP driver from being built).
132154
#ifdef FUSILLI_ENABLE_AMDGPU
155+
// Mutex for thread-safe access to the GPU device cache.
156+
static std::mutex gpuDeviceMutex;
157+
158+
// Cache key for AMDGPU devices: (deviceId, stream) pair.
159+
// Devices with the same configuration are shared across handles.
160+
using GpuDeviceKey = std::pair<int, uintptr_t>;
161+
static std::map<GpuDeviceKey, std::weak_ptr<iree_hal_device_t>>
162+
gpuDeviceCache;
163+
164+
// Serialize access to the cache check-then-create logic.
165+
std::lock_guard<std::mutex> lock(gpuDeviceMutex);
166+
167+
// Clean up all expired entries while we hold the lock.
168+
std::erase_if(gpuDeviceCache, // C++20
169+
[](const auto &entry) { return entry.second.expired(); });
170+
171+
GpuDeviceKey key{deviceId, stream};
172+
173+
// Try to get an existing device from the cache.
174+
if (auto it = gpuDeviceCache.find(key); it != gpuDeviceCache.end()) {
175+
// Entry exists and is valid (we just cleaned expired ones).
176+
FUSILLI_LOG_LABEL_ENDL("INFO: Reusing cached AMDGPU device");
177+
// Lock the weak_ptr to get a shared_ptr and assign to device_.
178+
device_ = it->second.lock();
179+
return ok();
180+
}
181+
182+
// Create a new device since none exists in cache for this configuration.
183+
FUSILLI_LOG_LABEL_ENDL("INFO: Creating new AMDGPU device");
184+
133185
// Device parms.
134186
iree_hal_hip_device_params_t params;
135187
setDefaultIreeHalHipDeviceParams(&params);
@@ -149,9 +201,14 @@ inline ErrorObject Handle::createAMDGPUDevice(int deviceId, uintptr_t stream) {
149201
driver, HIP_DEVICE_ID_TO_IREE_DEVICE_ID(deviceId), /*param_count=*/0,
150202
/*params=*/nullptr, iree_allocator_system(), &rawDevice));
151203

152-
// Wrap the raw device ptr with a unique_ptr and custom deleter
204+
// Wrap the raw device ptr with a shared_ptr and custom deleter
153205
// for lifetime management.
154-
device_ = IreeHalDeviceUniquePtrType(rawDevice);
206+
IreeHalDeviceSharedPtrType sharedDevice(rawDevice, IreeHalDeviceDeleter());
207+
208+
// Cache the device for future handles with the same configuration.
209+
gpuDeviceCache[key] = sharedDevice;
210+
211+
device_ = sharedDevice;
155212
return ok();
156213
#else
157214
return ErrorObject(ErrorCode::InternalError,

tests/test_handle.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,48 @@ TEST_CASE("Multi-threaded Handle creation", "[handle][thread]") {
8484
REQUIRE(handles.size() == kNumThreads);
8585
}
8686

87+
#ifdef FUSILLI_ENABLE_AMDGPU
88+
TEST_CASE("Multi-threaded Handle creation AMDGPU", "[handle][thread][amdgpu]") {
89+
constexpr int kNumThreads = 32;
90+
91+
std::vector<std::thread> threads;
92+
threads.reserve(kNumThreads);
93+
94+
std::vector<Handle> handles;
95+
handles.reserve(kNumThreads);
96+
97+
// Create a barrier to force threads to start simultaneously.
98+
std::barrier startBarrier(kNumThreads);
99+
100+
// Atomic flag to track failures during handle creation.
101+
std::atomic<bool> creationFailed{false};
102+
103+
// Mutex for pushing to handles in a thread-safe manner.
104+
std::mutex handlesMutex;
105+
106+
for (size_t i = 0; i < kNumThreads; ++i) {
107+
threads.emplace_back([&]() {
108+
// Wait at the barrier until all threads reach this point.
109+
startBarrier.arrive_and_wait();
110+
// Create the handle for AMDGPU backend.
111+
auto handleOrError = Handle::create(Backend::AMDGPU);
112+
if (isError(handleOrError)) {
113+
creationFailed.store(true);
114+
return;
115+
}
116+
std::lock_guard<std::mutex> lock(handlesMutex);
117+
handles.push_back(std::move(*handleOrError));
118+
});
119+
}
120+
// Wait for all threads to finish.
121+
for (auto &t : threads)
122+
t.join();
123+
124+
REQUIRE(!creationFailed.load());
125+
REQUIRE(handles.size() == kNumThreads);
126+
}
127+
#endif
128+
87129
TEST_CASE("Handle creation with deviceId and stream, CPU backend should fail",
88130
"[handle]") {
89131
// Attempting to create CPU handle with with a specific deviceId and stream

0 commit comments

Comments
 (0)