diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 65cd32d30f..90a6331441 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -2,6 +2,7 @@

 #include
 #include
+#include <shared_mutex>

 #define NS_PRIVATE_IMPLEMENTATION
 #define CA_PRIVATE_IMPLEMENTATION
@@ -362,6 +363,19 @@ Device::~Device() {
   device_->release();
 }

+DeviceStream& Device::get_stream_nolock(int index) {
+  auto it = stream_map_.find(index);
+  if (it == stream_map_.end()) {
+    throw std::out_of_range("[metal::Device] Invalid stream index requested.");
+  }
+  return *it->second;
+}
+
+DeviceStream* Device::get_stream_ptr(int index) {
+  std::shared_lock lk(stream_map_mtx_);
+  return &get_stream_nolock(index);
+}
+
 void Device::new_queue(int index) {
   auto thread_pool = metal::new_scoped_memory_pool();
   auto q = device_->newCommandQueue();
@@ -370,24 +384,16 @@ void Device::new_queue(int index) {
     throw std::runtime_error(
         "[metal::Device] Failed to make new command queue.");
   }
-  stream_map_.emplace(index, q);
+  {
+    std::unique_lock lk(stream_map_mtx_);
+    stream_map_.emplace(index, std::make_unique<DeviceStream>(q));
+  }
   if (residency_set_ != nullptr) {
     q->addResidencySet(residency_set_);
   }
 }

-MTL::CommandQueue* Device::get_queue(Stream stream) {
-  return get_stream_(stream.index).queue;
-}
-
-bool Device::command_buffer_needs_commit(int index) {
-  auto& stream = get_stream_(index);
-  return (stream.buffer_ops > max_ops_per_buffer_) ||
-      ((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
-}
-
-MTL::CommandBuffer* Device::get_command_buffer(int index) {
-  auto& stream = get_stream_(index);
+MTL::CommandBuffer* Device::ensure_command_buffer(DeviceStream& stream) {
   if (stream.buffer == nullptr) {
     stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
     if (!stream.buffer) {
@@ -400,33 +406,50 @@ MTL::CommandBuffer* Device::get_command_buffer(int index) {
   return stream.buffer;
 }

+MTL::CommandQueue* Device::get_queue(Stream stream) {
+  auto* stream_ref = get_stream_ptr(stream.index);
+  return stream_ref->queue;
+}
+
+bool Device::command_buffer_needs_commit(int index) {
+  auto* stream = get_stream_ptr(index);
+  return (stream->buffer_ops > max_ops_per_buffer_) ||
+      ((stream->buffer_sizes >> 20) > max_mb_per_buffer_);
+}
+
+MTL::CommandBuffer* Device::get_command_buffer(int index) {
+  auto* stream = get_stream_ptr(index);
+  return ensure_command_buffer(*stream);
+}
+
 void Device::commit_command_buffer(int index) {
-  auto& stream = get_stream_(index);
-  stream.buffer->commit();
-  stream.buffer->release();
-  stream.buffer = nullptr;
-  stream.buffer_ops = 0;
-  stream.buffer_sizes = 0;
+  auto* stream = get_stream_ptr(index);
+  stream->buffer->commit();
+  stream->buffer->release();
+  stream->buffer = nullptr;
+  stream->buffer_ops = 0;
+  stream->buffer_sizes = 0;
 }

 void Device::add_temporary(array arr, int index) {
-  get_stream_(index).temporaries.push_back(std::move(arr));
+  auto* stream = get_stream_ptr(index);
+  stream->temporaries.push_back(std::move(arr));
 }

 void Device::add_temporaries(std::vector<array> arrays, int index) {
   if (arrays.empty()) {
     return;
   }
-  auto& stream = get_stream_(index);
-  stream.temporaries.insert(
-      stream.temporaries.end(),
+  auto* stream = get_stream_ptr(index);
+  stream->temporaries.insert(
+      stream->temporaries.end(),
       std::make_move_iterator(arrays.begin()),
       std::make_move_iterator(arrays.end()));
 }

 void Device::end_encoding(int index) {
-  auto& stream = get_stream_(index);
-  if (stream.encoder != nullptr) {
+  auto* stream = get_stream_ptr(index);
+  if (stream->encoder != nullptr) {
     // Each command encoder has a unique fence. We also store a map of
     // all previous outputs of command encoders to their corresponding fence.
     // - The command encoder records its inputs and outputs.
@@ -439,9 +462,9 @@ void Device::end_encoding(int index) {
     // - Temporaries are a special case as they do not cross command encoder
     //   boundaries. These can be removed early from the encoders inputs and
     //   outputs since they don't need synchronization.
-    auto& enc = *stream.encoder;
+    auto& enc = *stream->encoder;
     // Remove temporaries from inputs and outputs
-    for (auto& t : stream.temporaries) {
+    for (auto& t : stream->temporaries) {
       enc.outputs().erase(t.buffer().ptr());
       enc.inputs().erase(t.buffer().ptr());
     }
@@ -450,9 +473,9 @@ void Device::end_encoding(int index) {
     // in the completion handler so they are not prematurely released
     std::unordered_set<std::shared_ptr<Fence>> waiting_on;
     {
-      std::lock_guard lk(stream.fence_mtx);
+      std::lock_guard lk(stream->fence_mtx);
       for (auto in : enc.inputs()) {
-        if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
+        if (auto it = stream->outputs.find(in); it != stream->outputs.end()) {
           // If we've already waited on a fence, don't wait on it again.
           if (waiting_on.find(it->second) == waiting_on.end()) {
             enc.wait_for_fence(it->second->fence);
@@ -461,42 +484,40 @@
         }
       }
       for (auto out : enc.outputs()) {
-        stream.outputs[out] = stream.fence;
+        stream->outputs[out] = stream->fence;
       }
     }
-    enc.update_fence(stream.fence->fence);
-    stream.buffer->addCompletedHandler(
-        [&stream,
-         waiting_on = std::move(waiting_on),
-         fence = std::move(stream.fence),
-         outputs = std::move(enc.outputs()),
-         temporaries =
-             std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
-          temporaries.clear();
-          std::lock_guard lk(stream.fence_mtx);
-          for (auto o : outputs) {
-            if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
-              if (it->second == fence) {
-                stream.outputs.erase(it);
-              }
-            }
+    enc.update_fence(stream->fence->fence);
+    stream->buffer->addCompletedHandler([stream,
+                                         waiting_on = std::move(waiting_on),
+                                         fence = std::move(stream->fence),
+                                         outputs = std::move(enc.outputs()),
+                                         temporaries =
+                                             std::move(stream->temporaries)](
+                                            MTL::CommandBuffer*) mutable {
+      temporaries.clear();
+      std::lock_guard lk(stream->fence_mtx);
+      for (auto o : outputs) {
+        if (auto it = stream->outputs.find(o); it != stream->outputs.end()) {
+          if (it->second == fence) {
+            stream->outputs.erase(it);
           }
-        });
+        }
+      }
+    });
   }
-  stream.encoder = nullptr;
+  stream->encoder = nullptr;
 }

 CommandEncoder& Device::get_command_encoder(int index) {
-  auto& stream = get_stream_(index);
-  if (stream.encoder == nullptr) {
+  auto* stream = get_stream_ptr(index);
+  if (stream->encoder == nullptr) {
     // Ensure there is an active command buffer
-    if (stream.buffer == nullptr) {
-      get_command_buffer(index);
-    }
-    stream.encoder = std::make_unique<CommandEncoder>(stream);
-    stream.fence = std::make_shared<Fence>(device_->newFence());
+    ensure_command_buffer(*stream);
+    stream->encoder = std::make_unique<CommandEncoder>(*stream);
+    stream->fence = std::make_shared<Fence>(device_->newFence());
   }
-  return *stream.encoder;
+  return *stream->encoder;
 }

 MTL::Library* Device::get_library(
@@ -742,8 +763,12 @@ MTL::ComputePipelineState* Device::get_kernel_(
   auto mtl_linked_funcs = get_linked_functions_(linked_functions);
   auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);

-  mtl_function->release();
-  mtl_linked_funcs->release();
+  if (mtl_function) {
+    mtl_function->release();
+  }
+  if (mtl_linked_funcs) {
+    mtl_linked_funcs->release();
+  }

   // Add kernel to cache
   kernel_map_.insert({hash_name, kernel});

@@ -790,8 +815,9 @@ void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
   }
   residency_set_ = residency_set;
   // Attach residency set to existing command queues
+  std::shared_lock lk(stream_map_mtx_);
   for (auto& [_, stream] : stream_map_) {
-    stream.queue->addResidencySet(residency_set_);
+    stream->queue->addResidencySet(residency_set_);
   }
 }
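
Note on the locking scheme in `device.cpp` above: stream creation (`new_queue`) takes a `std::unique_lock` on `stream_map_mtx_`, while the per-operation accessors go through `get_stream_ptr`, which holds a `std::shared_lock` only for the lookup. Handing out a raw `DeviceStream*` after the lock is released is safe here because entries are heap-allocated (`std::unique_ptr`) and never erased from `stream_map_`. A minimal standalone sketch of the same readers-writer pattern (the `Registry`/`Entry` names are illustrative, not MLX types):

```cpp
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <stdexcept>
#include <unordered_map>

struct Entry {
  int value;
};

class Registry {
 public:
  // Writer path (rare): exclusive lock while inserting a new entry.
  void add(int index, int value) {
    std::unique_lock lk(mtx_);
    map_.emplace(index, std::make_unique<Entry>(Entry{value}));
  }

  // Reader path (hot): the shared lock only guards the lookup. The returned
  // pointer stays valid after the lock is released because entries are
  // heap-allocated and never erased in this sketch.
  Entry* get(int index) {
    std::shared_lock lk(mtx_);
    auto it = map_.find(index);
    if (it == map_.end()) {
      throw std::out_of_range("Invalid index requested.");
    }
    return it->second.get();
  }

 private:
  std::shared_mutex mtx_;
  std::unordered_map<int, std::unique_ptr<Entry>> map_;
};
```
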
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 746fbc0888..5b96df9776 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -4,6 +4,7 @@

 #include
 #include
+#include
 #include
 #include
 #include
@@ -209,7 +210,7 @@ class Device {

  private:
   DeviceStream& get_stream_(int index) {
-    return *stream_map_.find(index)->second;
+    return *stream_map_.find(index)->second;
   }

   MTL::Library* get_library_cache_(const std::string& name);
@@ -243,8 +244,13 @@ class Device {
       const MTLFCList& func_consts = {},
       const std::vector<MTL::Function*>& linked_functions = {});

+  DeviceStream& get_stream_nolock(int index);
+  DeviceStream* get_stream_ptr(int index);
+  MTL::CommandBuffer* ensure_command_buffer(DeviceStream& stream);
+
   MTL::Device* device_;
-  std::unordered_map<int, DeviceStream> stream_map_;
+  mutable std::shared_mutex stream_map_mtx_;
+  std::unordered_map<int, std::unique_ptr<DeviceStream>> stream_map_;

   std::shared_mutex kernel_mtx_;
   std::shared_mutex library_mtx_;
diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp
index b19f6434a1..423499ee5d 100644
--- a/mlx/scheduler.cpp
+++ b/mlx/scheduler.cpp
@@ -55,6 +55,11 @@ void synchronize() {

 namespace scheduler {

+std::mutex& stream_creation_mutex() {
+  static std::mutex mutex;
+  return mutex;
+}
+
 /** A singleton scheduler to manage devices, streams, and task execution. */
 Scheduler& scheduler() {
   // Leak the scheduler on Windows to avoid joining threads on exit, can be
diff --git a/mlx/scheduler.h b/mlx/scheduler.h
index 877fdd5f6a..b1bfdcedea 100644
--- a/mlx/scheduler.h
+++ b/mlx/scheduler.h
@@ -2,8 +2,8 @@

 #pragma once

-#include
-#include
+#include
+#include
 #include
 #include
 #include
@@ -14,6 +14,8 @@

 namespace mlx::core::scheduler {

+std::mutex& stream_creation_mutex();
+
 struct StreamThread {
   std::mutex mtx;
   std::queue<std::function<void()>> q;
@@ -79,14 +81,22 @@ class Scheduler {
   Scheduler& operator=(Scheduler&&) = delete;

   Stream new_stream(const Device& d) {
-    streams_.emplace_back(streams_.size(), d);
+    // Lock the mutex to ensure the stream is created in a thread-safe
+    // manner; stream creation is otherwise not thread-safe.
+    std::lock_guard lk(stream_creation_mutex());
+    const auto new_stream_index = static_cast<int>(streams_.size());
+    Stream stream(new_stream_index, d);
+    streams_.push_back(stream);
+
+    // Create the stream (GPU) or thread (CPU)
     if (d == Device::gpu) {
       threads_.push_back(nullptr);
-      gpu::new_stream(streams_.back());
+      gpu::new_stream(stream);
     } else {
-      threads_.push_back(new StreamThread{});
+      auto stream_thread = std::make_unique<StreamThread>();
+      threads_.push_back(stream_thread.release());
     }
-    return streams_.back();
+    return stream;
   }

   template <typename F>
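
The `stream_creation_mutex()` added above is a function-local static (a Meyers singleton): C++11 guarantees its one-time initialization is thread-safe even if the very first callers race to reach it. A reduced sketch of the pattern, with a hypothetical `next_index()` standing in for the size-read-then-append that `new_stream` now performs under the lock:

```cpp
#include <mutex>

// Function-local static: initialized exactly once, in a thread-safe way
// (C++11 "magic statics"), even if the first callers arrive concurrently.
std::mutex& creation_mutex() {
  static std::mutex m;
  return m;
}

// Hypothetical stand-in for the index assignment in Scheduler::new_stream:
// reading the current count and bumping it happen under one lock, so two
// concurrent callers can never be handed the same index.
int next_index() {
  static int counter = 0;
  std::lock_guard<std::mutex> lk(creation_mutex());
  return counter++;
}
```

This distinct-index property is exactly what the new test below checks.
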
diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp
index bbf97d05c4..b1df38417b 100644
--- a/tests/scheduler_tests.cpp
+++ b/tests/scheduler_tests.cpp
@@ -2,6 +2,12 @@

 #include "doctest/doctest.h"

+#include <future>
+#include <mutex>
+#include <thread>
+#include <unordered_set>
+#include <vector>
+
 #include "mlx/mlx.h"
 #include "mlx/scheduler.h"

@@ -107,6 +113,36 @@ TEST_CASE("test stream placement") {
   }
 }

+TEST_CASE("test concurrent stream creation") {
+  constexpr int kNumThreads = 16;
+  std::promise<void> go;
+  auto start = go.get_future().share();
+
+  std::mutex results_mtx;
+  std::vector<int> indices;
+  indices.reserve(kNumThreads);
+  std::vector<std::thread> threads;
+  threads.reserve(kNumThreads);
+
+  for (int i = 0; i < kNumThreads; ++i) {
+    threads.emplace_back([&]() {
+      start.wait();
+      auto s = new_stream(Device::cpu);
+      std::lock_guard lk(results_mtx);
+      indices.push_back(s.index);
+    });
+  }
+
+  go.set_value();
+  for (auto& t : threads) {
+    t.join();
+  }
+
+  CHECK_EQ(indices.size(), static_cast<size_t>(kNumThreads));
+  std::unordered_set<int> unique_indices(indices.begin(), indices.end());
+  CHECK_EQ(unique_indices.size(), indices.size());
+}
+
 TEST_CASE("test scheduler races") {
   auto x = zeros({1});
   auto y = zeros({100});
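
The test relies on a `std::promise<void>`/`std::shared_future<void>` pair as a start gate, so all worker threads call `new_stream` at roughly the same moment and contention on the creation path is maximized. The same pattern in isolation, independent of MLX (the lambda body is a placeholder for whatever call is being stress-tested):

```cpp
#include <future>
#include <thread>
#include <vector>

int main() {
  std::promise<void> go;
  std::shared_future<void> start = go.get_future().share();

  std::vector<std::thread> workers;
  for (int i = 0; i < 8; ++i) {
    workers.emplace_back([start]() {
      start.wait();  // every worker blocks here until the gate opens
      // ... the contended call under test goes here ...
    });
  }

  go.set_value();  // open the gate: all workers proceed together
  for (auto& t : workers) {
    t.join();
  }
}
```
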