diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 91743ec04b..fbe83e9b85 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include #include +#include #include #include #include @@ -305,6 +306,7 @@ class CompilerCache { const std::vector& inputs, bool shapeless, const std::vector& constants) { + std::lock_guard guard(mtx_); // Find the cache entries for |fun_id|. std::vector& entries = cache_[fun_id]; @@ -353,10 +355,12 @@ class CompilerCache { } void erase(std::uintptr_t fun_id) { + std::lock_guard guard(mtx_); cache_.erase(fun_id); } void clear() { + std::lock_guard guard(mtx_); cache_.clear(); } @@ -368,6 +372,7 @@ class CompilerCache { } friend CompilerCache& compiler_cache(); + std::mutex mtx_; std::unordered_map> cache_; };