From 77d465840552c679d3fe1cea58858a597d8024b5 Mon Sep 17 00:00:00 2001 From: Jack Date: Wed, 5 Nov 2025 17:45:40 -0500 Subject: [PATCH 1/4] . --- mlx/backend/metal/device.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 6d4d2841d7..c03475109c 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -750,8 +750,12 @@ MTL::ComputePipelineState* Device::get_kernel_( auto mtl_linked_funcs = get_linked_functions_(linked_functions); auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs); - mtl_function->release(); - mtl_linked_funcs->release(); + if (mtl_function) { + mtl_function->release(); + } + if (mtl_linked_funcs) { + mtl_linked_funcs->release(); + } // Add kernel to cache kernel_map_.insert({hash_name, kernel}); From 84ffced5b971f3009056122283c235b7a610b4a4 Mon Sep 17 00:00:00 2001 From: Jack Date: Thu, 20 Nov 2025 09:54:39 -0500 Subject: [PATCH 2/4] . --- mlx/backend/metal/device.cpp | 119 +++++++++++++++++++++-------------- mlx/backend/metal/device.h | 10 ++- mlx/scheduler.cpp | 5 ++ mlx/scheduler.h | 29 +++++++-- tests/scheduler_tests.cpp | 36 +++++++++++ 5 files changed, 143 insertions(+), 56 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 600ac03c05..f4dec76ea7 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -2,6 +2,7 @@ #include #include +#include #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION @@ -362,6 +363,20 @@ Device::~Device() { device_->release(); } +DeviceStream& Device::get_stream_nolock(int index) { + auto it = stream_map_.find(index); + if (it == stream_map_.end()) { + throw std::out_of_range( + "[metal::Device] Invalid stream index requested."); + } + return *it->second; +} + +DeviceStream* Device::get_stream_ptr(int index) { + std::shared_lock lk(stream_map_mtx_); + return &get_stream_nolock(index); +} + void Device::new_queue(int index) { auto thread_pool = metal::new_scoped_memory_pool(); auto q = device_->newCommandQueue(); @@ -370,24 +385,16 @@ void Device::new_queue(int index) { throw std::runtime_error( "[metal::Device] Failed to make new command queue."); } - stream_map_.emplace(index, q); + { + std::unique_lock lk(stream_map_mtx_); + stream_map_.emplace(index, std::make_unique(q)); + } if (residency_set_ != nullptr) { q->addResidencySet(residency_set_); } } -MTL::CommandQueue* Device::get_queue(Stream stream) { - return get_stream_(stream.index).queue; -} - -bool Device::command_buffer_needs_commit(int index) { - auto& stream = get_stream_(index); - return (stream.buffer_ops > max_ops_per_buffer_) || - ((stream.buffer_sizes >> 20) > max_mb_per_buffer_); -} - -MTL::CommandBuffer* Device::get_command_buffer(int index) { - auto& stream = get_stream_(index); +MTL::CommandBuffer* Device::ensure_command_buffer(DeviceStream& stream) { if (stream.buffer == nullptr) { stream.buffer = stream.queue->commandBufferWithUnretainedReferences(); if (!stream.buffer) { @@ -400,33 +407,50 @@ MTL::CommandBuffer* Device::get_command_buffer(int index) { return stream.buffer; } +MTL::CommandQueue* Device::get_queue(Stream stream) { + auto* stream_ref = get_stream_ptr(stream.index); + return stream_ref->queue; +} + +bool Device::command_buffer_needs_commit(int index) { + auto* stream = get_stream_ptr(index); + return (stream->buffer_ops > max_ops_per_buffer_) || + ((stream->buffer_sizes >> 20) > max_mb_per_buffer_); +} + +MTL::CommandBuffer* Device::get_command_buffer(int index) { + auto* stream = get_stream_ptr(index); + return ensure_command_buffer(*stream); +} + void Device::commit_command_buffer(int index) { - auto& stream = get_stream_(index); - stream.buffer->commit(); - stream.buffer->release(); - stream.buffer = nullptr; - stream.buffer_ops = 0; - stream.buffer_sizes = 0; + auto* stream = get_stream_ptr(index); + stream->buffer->commit(); + stream->buffer->release(); + stream->buffer = nullptr; + stream->buffer_ops = 0; + stream->buffer_sizes = 0; } void Device::add_temporary(array arr, int index) { - get_stream_(index).temporaries.push_back(std::move(arr)); + auto* stream = get_stream_ptr(index); + stream->temporaries.push_back(std::move(arr)); } void Device::add_temporaries(std::vector arrays, int index) { if (arrays.empty()) { return; } - auto& stream = get_stream_(index); - stream.temporaries.insert( - stream.temporaries.end(), + auto* stream = get_stream_ptr(index); + stream->temporaries.insert( + stream->temporaries.end(), std::make_move_iterator(arrays.begin()), std::make_move_iterator(arrays.end())); } void Device::end_encoding(int index) { - auto& stream = get_stream_(index); - if (stream.encoder != nullptr) { + auto* stream = get_stream_ptr(index); + if (stream->encoder != nullptr) { // Each command encoder has a unique fence. We also store a map of // all previous outputs of command encoders to their corresponding fence. // - The command encoder records its inputs and outputs. @@ -439,9 +463,9 @@ void Device::end_encoding(int index) { // - Temporaries are a special case as they do not cross command encoder // boundaries. These can be removed early from the encoders inputs and // outputs since they don't need synchronization. - auto& enc = *stream.encoder; + auto& enc = *stream->encoder; // Remove temporaries from inputs and outputs - for (auto& t : stream.temporaries) { + for (auto& t : stream->temporaries) { enc.outputs().erase(t.buffer().ptr()); enc.inputs().erase(t.buffer().ptr()); } @@ -450,9 +474,9 @@ void Device::end_encoding(int index) { // in the completion handler so they are not prematurely released std::unordered_set> waiting_on; { - std::lock_guard lk(stream.fence_mtx); + std::lock_guard lk(stream->fence_mtx); for (auto in : enc.inputs()) { - if (auto it = stream.outputs.find(in); it != stream.outputs.end()) { + if (auto it = stream->outputs.find(in); it != stream->outputs.end()) { // If we've already waited on a fence, don't wait on it again. if (waiting_on.find(it->second) == waiting_on.end()) { enc.wait_for_fence(it->second->fence); @@ -461,42 +485,40 @@ void Device::end_encoding(int index) { } } for (auto out : enc.outputs()) { - stream.outputs[out] = stream.fence; + stream->outputs[out] = stream->fence; } } - enc.update_fence(stream.fence->fence); - stream.buffer->addCompletedHandler( - [&stream, + enc.update_fence(stream->fence->fence); + stream->buffer->addCompletedHandler( + [stream, waiting_on = std::move(waiting_on), - fence = std::move(stream.fence), + fence = std::move(stream->fence), outputs = std::move(enc.outputs()), temporaries = - std::move(stream.temporaries)](MTL::CommandBuffer*) mutable { + std::move(stream->temporaries)](MTL::CommandBuffer*) mutable { temporaries.clear(); - std::lock_guard lk(stream.fence_mtx); + std::lock_guard lk(stream->fence_mtx); for (auto o : outputs) { - if (auto it = stream.outputs.find(o); it != stream.outputs.end()) { + if (auto it = stream->outputs.find(o); it != stream->outputs.end()) { if (it->second == fence) { - stream.outputs.erase(it); + stream->outputs.erase(it); } } } }); } - stream.encoder = nullptr; + stream->encoder = nullptr; } CommandEncoder& Device::get_command_encoder(int index) { - auto& stream = get_stream_(index); - if (stream.encoder == nullptr) { + auto* stream = get_stream_ptr(index); + if (stream->encoder == nullptr) { // Ensure there is an active command buffer - if (stream.buffer == nullptr) { - get_command_buffer(index); - } - stream.encoder = std::make_unique(stream); - stream.fence = std::make_shared(device_->newFence()); + ensure_command_buffer(*stream); + stream->encoder = std::make_unique(*stream); + stream->fence = std::make_shared(device_->newFence()); } - return *stream.encoder; + return *stream->encoder; } MTL::Library* Device::get_library( @@ -794,8 +816,9 @@ void Device::set_residency_set(const MTL::ResidencySet* residency_set) { } residency_set_ = residency_set; // Attach residency set to existing command queues + std::shared_lock lk(stream_map_mtx_); for (auto& [_, stream] : stream_map_) { - stream.queue->addResidencySet(residency_set_); + stream->queue->addResidencySet(residency_set_); } } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 746fbc0888..5b96df9776 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -209,7 +210,7 @@ class Device { private: DeviceStream& get_stream_(int index) { - return stream_map_.find(index)->second; + return *stream_map_.find(index)->second; } MTL::Library* get_library_cache_(const std::string& name); @@ -243,8 +244,13 @@ class Device { const MTLFCList& func_consts = {}, const std::vector& linked_functions = {}); + DeviceStream& get_stream_nolock(int index); + DeviceStream* get_stream_ptr(int index); + MTL::CommandBuffer* ensure_command_buffer(DeviceStream& stream); + MTL::Device* device_; - std::unordered_map stream_map_; + mutable std::shared_mutex stream_map_mtx_; + std::unordered_map> stream_map_; std::shared_mutex kernel_mtx_; std::shared_mutex library_mtx_; diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp index b19f6434a1..423499ee5d 100644 --- a/mlx/scheduler.cpp +++ b/mlx/scheduler.cpp @@ -55,6 +55,11 @@ void synchronize() { namespace scheduler { +std::mutex& stream_creation_mutex() { + static std::mutex mutex; + return mutex; +} + /** A singleton scheduler to manage devices, streams, and task execution. */ Scheduler& scheduler() { // Leak the scheduler on Windows to avoid joining threads on exit, can be diff --git a/mlx/scheduler.h b/mlx/scheduler.h index 877fdd5f6a..1975849838 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -4,6 +4,8 @@ #include #include +#include +#include #include #include #include @@ -14,6 +16,8 @@ namespace mlx::core::scheduler { +std::mutex& stream_creation_mutex(); + struct StreamThread { std::mutex mtx; std::queue> q; @@ -79,14 +83,27 @@ class Scheduler { Scheduler& operator=(Scheduler&&) = delete; Stream new_stream(const Device& d) { - streams_.emplace_back(streams_.size(), d); + std::unique_ptr stream_thread; + if (d != Device::gpu) { + stream_thread = std::make_unique(); + } + + Stream stream; + { + std::lock_guard lk(stream_creation_mutex()); + streams_.emplace_back(streams_.size(), d); + stream = streams_.back(); + if (d == Device::gpu) { + threads_.push_back(nullptr); + } else { + threads_.push_back(stream_thread.release()); + } + } + if (d == Device::gpu) { - threads_.push_back(nullptr); - gpu::new_stream(streams_.back()); - } else { - threads_.push_back(new StreamThread{}); + gpu::new_stream(stream); } - return streams_.back(); + return stream; } template diff --git a/tests/scheduler_tests.cpp b/tests/scheduler_tests.cpp index bbf97d05c4..b1df38417b 100644 --- a/tests/scheduler_tests.cpp +++ b/tests/scheduler_tests.cpp @@ -2,6 +2,12 @@ #include "doctest/doctest.h" +#include +#include +#include +#include +#include + #include "mlx/mlx.h" #include "mlx/scheduler.h" @@ -107,6 +113,36 @@ TEST_CASE("test stream placement") { } } +TEST_CASE("test concurrent stream creation") { + constexpr int kNumThreads = 16; + std::promise go; + auto start = go.get_future().share(); + + std::mutex results_mtx; + std::vector indices; + indices.reserve(kNumThreads); + std::vector threads; + threads.reserve(kNumThreads); + + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back([&]() { + start.wait(); + auto s = new_stream(Device::cpu); + std::lock_guard lk(results_mtx); + indices.push_back(s.index); + }); + } + + go.set_value(); + for (auto& t : threads) { + t.join(); + } + + CHECK_EQ(indices.size(), static_cast(kNumThreads)); + std::unordered_set unique_indices(indices.begin(), indices.end()); + CHECK_EQ(unique_indices.size(), indices.size()); +} + TEST_CASE("test scheduler races") { auto x = zeros({1}); auto y = zeros({100}); From 2767b2644d9fb8a384ba575d3ab2dd05dc5425e5 Mon Sep 17 00:00:00 2001 From: Jack Date: Thu, 20 Nov 2025 10:53:37 -0500 Subject: [PATCH 3/4] . --- mlx/scheduler.h | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/mlx/scheduler.h b/mlx/scheduler.h index 1975849838..64fe413b7a 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -2,8 +2,6 @@ #pragma once -#include -#include #include #include #include @@ -83,25 +81,20 @@ class Scheduler { Scheduler& operator=(Scheduler&&) = delete; Stream new_stream(const Device& d) { - std::unique_ptr stream_thread; - if (d != Device::gpu) { - stream_thread = std::make_unique(); - } - - Stream stream; - { - std::lock_guard lk(stream_creation_mutex()); - streams_.emplace_back(streams_.size(), d); - stream = streams_.back(); - if (d == Device::gpu) { - threads_.push_back(nullptr); - } else { - threads_.push_back(stream_thread.release()); - } - } - + // Lock the mutex to ensure that the stream is created in a thread-safe manner + // This is necessary because the stream creation is not thread-safe + std::lock_guard lk(stream_creation_mutex()); + const auto new_stream_index = static_cast(streams_.size()); + Stream stream(new_stream_index, d); + streams_.push_back(stream); + + // Create the stream (GPU) or thread (CPU) if (d == Device::gpu) { + threads_.push_back(nullptr); gpu::new_stream(stream); + } else { + auto stream_thread = std::make_unique(); + threads_.push_back(stream_thread.release()); } return stream; } From ba63a3bf35a9438319a8a2f00dc85eba8989bfcd Mon Sep 17 00:00:00 2001 From: Jack Date: Thu, 20 Nov 2025 11:17:27 -0500 Subject: [PATCH 4/4] formatting --- mlx/backend/metal/device.cpp | 35 +++++++++++++++++------------------ mlx/scheduler.h | 4 ++-- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index f4dec76ea7..90a6331441 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -366,8 +366,7 @@ Device::~Device() { DeviceStream& Device::get_stream_nolock(int index) { auto it = stream_map_.find(index); if (it == stream_map_.end()) { - throw std::out_of_range( - "[metal::Device] Invalid stream index requested."); + throw std::out_of_range("[metal::Device] Invalid stream index requested."); } return *it->second; } @@ -489,23 +488,23 @@ void Device::end_encoding(int index) { } } enc.update_fence(stream->fence->fence); - stream->buffer->addCompletedHandler( - [stream, - waiting_on = std::move(waiting_on), - fence = std::move(stream->fence), - outputs = std::move(enc.outputs()), - temporaries = - std::move(stream->temporaries)](MTL::CommandBuffer*) mutable { - temporaries.clear(); - std::lock_guard lk(stream->fence_mtx); - for (auto o : outputs) { - if (auto it = stream->outputs.find(o); it != stream->outputs.end()) { - if (it->second == fence) { - stream->outputs.erase(it); - } - } + stream->buffer->addCompletedHandler([stream, + waiting_on = std::move(waiting_on), + fence = std::move(stream->fence), + outputs = std::move(enc.outputs()), + temporaries = + std::move(stream->temporaries)]( + MTL::CommandBuffer*) mutable { + temporaries.clear(); + std::lock_guard lk(stream->fence_mtx); + for (auto o : outputs) { + if (auto it = stream->outputs.find(o); it != stream->outputs.end()) { + if (it->second == fence) { + stream->outputs.erase(it); } - }); + } + } + }); } stream->encoder = nullptr; } diff --git a/mlx/scheduler.h b/mlx/scheduler.h index 64fe413b7a..b1bfdcedea 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -81,8 +81,8 @@ class Scheduler { Scheduler& operator=(Scheduler&&) = delete; Stream new_stream(const Device& d) { - // Lock the mutex to ensure that the stream is created in a thread-safe manner - // This is necessary because the stream creation is not thread-safe + // Lock the mutex to ensure that the stream is created in a thread-safe + // manner This is necessary because the stream creation is not thread-safe std::lock_guard lk(stream_creation_mutex()); const auto new_stream_index = static_cast(streams_.size()); Stream stream(new_stream_index, d);