Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 85 additions & 59 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdlib>
#include <sstream>
#include <stdexcept>

#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
Expand Down Expand Up @@ -362,6 +363,19 @@ Device::~Device() {
device_->release();
}

// Finds the DeviceStream registered under `index` in stream_map_.
// This function takes no lock itself; presumably the caller already holds
// stream_map_mtx_ (get_stream_ptr does) -- confirm for any new call site.
// Throws std::out_of_range when no stream is registered for `index`.
DeviceStream& Device::get_stream_nolock(int index) {
  const auto entry = stream_map_.find(index);
  if (entry != stream_map_.end()) {
    return *entry->second;
  }
  throw std::out_of_range("[metal::Device] Invalid stream index requested.");
}

// Looks up the stream for `index` while holding a shared lock on
// stream_map_mtx_ and returns its address. The lock is released when this
// function returns, so the pointer is used by callers without the map lock.
// Propagates std::out_of_range from get_stream_nolock for an unknown index.
DeviceStream* Device::get_stream_ptr(int index) {
  std::shared_lock<std::shared_mutex> read_guard(stream_map_mtx_);
  DeviceStream& found = get_stream_nolock(index);
  return &found;
}

void Device::new_queue(int index) {
auto thread_pool = metal::new_scoped_memory_pool();
auto q = device_->newCommandQueue();
Expand All @@ -370,24 +384,16 @@ void Device::new_queue(int index) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
stream_map_.emplace(index, q);
{
std::unique_lock<std::shared_mutex> lk(stream_map_mtx_);
stream_map_.emplace(index, std::make_unique<DeviceStream>(q));
}
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
}

// Returns the command queue stored on the device stream that corresponds to
// `stream.index`.
MTL::CommandQueue* Device::get_queue(Stream stream) {
  auto& device_stream = get_stream_(stream.index);
  return device_stream.queue;
}

// Reports whether the stream's current command buffer has exceeded either
// the op-count limit or the size limit.
bool Device::command_buffer_needs_commit(int index) {
  auto& s = get_stream_(index);
  if (s.buffer_ops > max_ops_per_buffer_) {
    return true;
  }
  // Presumably buffer_sizes is in bytes, so >> 20 converts to MiB before the
  // comparison against max_mb_per_buffer_ -- TODO confirm units.
  return (s.buffer_sizes >> 20) > max_mb_per_buffer_;
}

MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto& stream = get_stream_(index);
MTL::CommandBuffer* Device::ensure_command_buffer(DeviceStream& stream) {
if (stream.buffer == nullptr) {
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (!stream.buffer) {
Expand All @@ -400,33 +406,50 @@ MTL::CommandBuffer* Device::get_command_buffer(int index) {
return stream.buffer;
}

// Returns the command queue stored on the device stream that corresponds to
// `stream.index`. The stream is resolved through get_stream_ptr.
MTL::CommandQueue* Device::get_queue(Stream stream) {
  return get_stream_ptr(stream.index)->queue;
}

// Reports whether the stream's current command buffer has exceeded either
// the op-count limit or the size limit.
bool Device::command_buffer_needs_commit(int index) {
  const auto* s = get_stream_ptr(index);
  if (s->buffer_ops > max_ops_per_buffer_) {
    return true;
  }
  // Presumably buffer_sizes is in bytes, so >> 20 converts to MiB before the
  // comparison against max_mb_per_buffer_ -- TODO confirm units.
  return (s->buffer_sizes >> 20) > max_mb_per_buffer_;
}

// Returns the command buffer of the stream at `index`, delegating to
// ensure_command_buffer to supply one.
MTL::CommandBuffer* Device::get_command_buffer(int index) {
  DeviceStream& target = *get_stream_ptr(index);
  return ensure_command_buffer(target);
}

// Commits the pending command buffer of the stream at `index`, releases the
// buffer handle, clears the stream's buffer pointer, and resets the stream's
// per-buffer op count and accumulated size to zero.
// NOTE(review): the buffer pointer is dereferenced without a null check --
// presumably callers only commit while a command buffer is active; confirm.
void Device::commit_command_buffer(int index) {
auto& stream = get_stream_(index);
stream.buffer->commit();
stream.buffer->release();
stream.buffer = nullptr;
stream.buffer_ops = 0;
stream.buffer_sizes = 0;
auto* stream = get_stream_ptr(index);
stream->buffer->commit();
stream->buffer->release();
stream->buffer = nullptr;
stream->buffer_ops = 0;
stream->buffer_sizes = 0;
}

// Appends `arr` to the temporaries list of the stream at `index`.
// `arr` is taken by value and moved into the list.
void Device::add_temporary(array arr, int index) {
get_stream_(index).temporaries.push_back(std::move(arr));
auto* stream = get_stream_ptr(index);
stream->temporaries.push_back(std::move(arr));
}

// Moves every element of `arrays` onto the end of the temporaries list of
// the stream at `index`. Returns immediately, without looking up the stream,
// when `arrays` is empty.
void Device::add_temporaries(std::vector<array> arrays, int index) {
if (arrays.empty()) {
return;
}
auto& stream = get_stream_(index);
stream.temporaries.insert(
stream.temporaries.end(),
auto* stream = get_stream_ptr(index);
stream->temporaries.insert(
stream->temporaries.end(),
// Move iterators transfer the elements instead of copying them.
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}

void Device::end_encoding(int index) {
auto& stream = get_stream_(index);
if (stream.encoder != nullptr) {
auto* stream = get_stream_ptr(index);
if (stream->encoder != nullptr) {
// Each command encoder has a unique fence. We also store a map of
// all previous outputs of command encoders to their corresponding fence.
// - The command encoder records its inputs and outputs.
Expand All @@ -439,9 +462,9 @@ void Device::end_encoding(int index) {
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoders inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
auto& enc = *stream->encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
for (auto& t : stream->temporaries) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
Expand All @@ -450,9 +473,9 @@ void Device::end_encoding(int index) {
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
std::lock_guard<std::mutex> lk(stream->fence_mtx);
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
if (auto it = stream->outputs.find(in); it != stream->outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc.wait_for_fence(it->second->fence);
Expand All @@ -461,42 +484,40 @@ void Device::end_encoding(int index) {
}
}
for (auto out : enc.outputs()) {
stream.outputs[out] = stream.fence;
stream->outputs[out] = stream->fence;
}
}
enc.update_fence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream.fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
temporaries.clear();
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto o : outputs) {
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
if (it->second == fence) {
stream.outputs.erase(it);
}
}
enc.update_fence(stream->fence->fence);
stream->buffer->addCompletedHandler([stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream->fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream->temporaries)](
MTL::CommandBuffer*) mutable {
temporaries.clear();
std::lock_guard<std::mutex> lk(stream->fence_mtx);
for (auto o : outputs) {
if (auto it = stream->outputs.find(o); it != stream->outputs.end()) {
if (it->second == fence) {
stream->outputs.erase(it);
}
});
}
}
});
}
stream.encoder = nullptr;
stream->encoder = nullptr;
}

// Returns the command encoder of the stream at `index`, creating it on
// demand. When the stream has no encoder, a command buffer is made available
// first, then a new CommandEncoder is constructed from the stream and a new
// Fence (wrapping device_->newFence()) is stored on the stream.
CommandEncoder& Device::get_command_encoder(int index) {
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
auto* stream = get_stream_ptr(index);
if (stream->encoder == nullptr) {
// Ensure there is an active command buffer
if (stream.buffer == nullptr) {
get_command_buffer(index);
}
stream.encoder = std::make_unique<CommandEncoder>(stream);
stream.fence = std::make_shared<Fence>(device_->newFence());
ensure_command_buffer(*stream);
stream->encoder = std::make_unique<CommandEncoder>(*stream);
stream->fence = std::make_shared<Fence>(device_->newFence());
}
return *stream.encoder;
return *stream->encoder;
}

MTL::Library* Device::get_library(
Expand Down Expand Up @@ -742,8 +763,12 @@ MTL::ComputePipelineState* Device::get_kernel_(
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);

mtl_function->release();
mtl_linked_funcs->release();
if (mtl_function) {
mtl_function->release();
}
if (mtl_linked_funcs) {
mtl_linked_funcs->release();
}

// Add kernel to cache
kernel_map_.insert({hash_name, kernel});
Expand Down Expand Up @@ -790,8 +815,9 @@ void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
std::shared_lock<std::shared_mutex> lk(stream_map_mtx_);
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
stream->queue->addResidencySet(residency_set_);
}
}

Expand Down
10 changes: 8 additions & 2 deletions mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <Metal/Metal.hpp>
#include <functional>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
Expand Down Expand Up @@ -209,7 +210,7 @@ class Device {

private:
// Returns a reference to the DeviceStream stored in stream_map_ for `index`.
// NOTE(review): the iterator returned by find() is dereferenced without being
// compared against end(), so an unknown index is not handled here, and no
// lock is taken in this accessor.
DeviceStream& get_stream_(int index) {
return stream_map_.find(index)->second;
return *stream_map_.find(index)->second;
}
MTL::Library* get_library_cache_(const std::string& name);

Expand Down Expand Up @@ -243,8 +244,13 @@ class Device {
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});

DeviceStream& get_stream_nolock(int index);
DeviceStream* get_stream_ptr(int index);
MTL::CommandBuffer* ensure_command_buffer(DeviceStream& stream);

MTL::Device* device_;
std::unordered_map<int32_t, DeviceStream> stream_map_;
mutable std::shared_mutex stream_map_mtx_;
std::unordered_map<int32_t, std::unique_ptr<DeviceStream>> stream_map_;

std::shared_mutex kernel_mtx_;
std::shared_mutex library_mtx_;
Expand Down
5 changes: 5 additions & 0 deletions mlx/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ void synchronize() {

namespace scheduler {

// Accessor for the single mutex used to serialize stream creation.
// Every call returns a reference to the same mutex object.
std::mutex& stream_creation_mutex() {
  // Function-local static: constructed on first use, shared by all callers.
  static std::mutex creation_guard;
  return creation_guard;
}

/** A singleton scheduler to manage devices, streams, and task execution. */
Scheduler& scheduler() {
// Leak the scheduler on Windows to avoid joining threads on exit, can be
Expand Down
22 changes: 16 additions & 6 deletions mlx/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

#pragma once

#include <atomic>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <unordered_map>
Expand All @@ -14,6 +14,8 @@

namespace mlx::core::scheduler {

std::mutex& stream_creation_mutex();

struct StreamThread {
std::mutex mtx;
std::queue<std::function<void()>> q;
Expand Down Expand Up @@ -79,14 +81,22 @@ class Scheduler {
Scheduler& operator=(Scheduler&&) = delete;

// Creates a new stream on device `d`, records it in streams_, and adds the
// matching entry to threads_: nullptr for a GPU stream (which is also handed
// to gpu::new_stream), or a new StreamThread for any other device.
// The new stream's index is the size of streams_ before insertion.
// Returns the newly created Stream.
Stream new_stream(const Device& d) {
streams_.emplace_back(streams_.size(), d);
// Lock the mutex to ensure that the stream is created in a thread-safe
// manner. This is necessary because stream creation is not otherwise
// thread-safe.
std::lock_guard<std::mutex> lk(stream_creation_mutex());
const auto new_stream_index = static_cast<int>(streams_.size());
Stream stream(new_stream_index, d);
streams_.push_back(stream);

// Create the stream (GPU) or thread (CPU)
if (d == Device::gpu) {
threads_.push_back(nullptr);
gpu::new_stream(streams_.back());
gpu::new_stream(stream);
} else {
threads_.push_back(new StreamThread{});
auto stream_thread = std::make_unique<StreamThread>();
// threads_ stores raw pointers, so ownership is released into the vector.
threads_.push_back(stream_thread.release());
}
return streams_.back();
return stream;
}

template <typename F>
Expand Down
36 changes: 36 additions & 0 deletions tests/scheduler_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

#include "doctest/doctest.h"

#include <future>
#include <mutex>
#include <thread>
#include <unordered_set>
#include <vector>

#include "mlx/mlx.h"
#include "mlx/scheduler.h"

Expand Down Expand Up @@ -107,6 +113,36 @@ TEST_CASE("test stream placement") {
}
}

// Starts kNumThreads threads that all call new_stream(Device::cpu) once a
// common start signal fires, then checks that every thread obtained a stream
// and that all returned stream indices are distinct.
TEST_CASE("test concurrent stream creation") {
  constexpr int kNumThreads = 16;
  std::promise<void> go;
  auto start = go.get_future().share();

  std::mutex results_mtx;
  std::vector<int> indices;
  indices.reserve(kNumThreads);
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);

  for (int i = 0; i < kNumThreads; ++i) {
    // Capture `start` by value so each thread waits through its own
    // shared_future copy; concurrent waits on a single shared_future object
    // are not the supported usage. The result containers are shared, so they
    // are captured by reference and guarded by results_mtx.
    threads.emplace_back([start, &results_mtx, &indices]() {
      start.wait();
      auto s = new_stream(Device::cpu);
      std::lock_guard<std::mutex> lk(results_mtx);
      indices.push_back(s.index);
    });
  }

  // Release all threads at once to maximize contention in new_stream.
  go.set_value();
  for (auto& t : threads) {
    t.join();
  }

  CHECK_EQ(indices.size(), static_cast<std::size_t>(kNumThreads));
  std::unordered_set<int> unique_indices(indices.begin(), indices.end());
  CHECK_EQ(unique_indices.size(), indices.size());
}

TEST_CASE("test scheduler races") {
auto x = zeros({1});
auto y = zeros({100});
Expand Down