@@ -28,6 +28,7 @@ limitations under the License.
 
 #include "absl/base/thread_annotations.h"
 #include "absl/cleanup/cleanup.h"
+#include "absl/hash/hash.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
@@ -78,34 +79,58 @@ class HashablePyDictIter {
   nb::detail::dict_iterator& iter_;
 };
 
+struct HashableKey {
+  nb::object context;
+  nb::args args;
+  nb::kwargs kwargs;
+
+  template <typename H>
+  friend H AbslHashValue(H h, const HashableKey& key) {
+    // Note: Despite the fact this is an ABSL hash function, it's safe to call
+    // functions that may throw exceptions such as nb::hash(), because it is
+    // used by an LRUCache, which uses a std::unordered_map, which is
+    // exception-safe.
+    h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
+    nb::detail::dict_iterator begin = key.kwargs.begin();
+    nb::detail::dict_iterator end = key.kwargs.end();
+    h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
+                             HashablePyDictIter(end));
+    h = H::combine(std::move(h), key.kwargs.size());
+    return h;
+  }
+};
+
 }  // namespace
 
 class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
  public:
-  struct Key {
-    nb::object context;
-    nb::args args;
-    nb::kwargs kwargs;
+  class Key {
+   public:
+    Key(nb::object context, nb::args args, nb::kwargs kwargs)
+        : context_(std::move(context)),
+          args_(std::move(args)),
+          kwargs_(std::move(kwargs)),
+          cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {}
 
     bool operator==(const Key& other) const {
-      return context.equal(other.context) && args.equal(other.args) &&
-             kwargs.equal(other.kwargs);
+      return context_.equal(other.context_) && args_.equal(other.args_) &&
+             kwargs_.equal(other.kwargs_);
     }
 
     template <typename H>
     friend H AbslHashValue(H h, const Key& key) {
-      // Note: Despite the fact this is an ABSL hash function, it's safe to call
-      // functions that may throw exceptions such as nb::hash(), because it is
-      // used by an LRUCache, which uses a std::unordered_map, which is
-      // exception-safe.
-      h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
-      nb::detail::dict_iterator begin = key.kwargs.begin();
-      nb::detail::dict_iterator end = key.kwargs.end();
-      h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
-                               HashablePyDictIter(end));
-      h = H::combine(std::move(h), key.kwargs.size());
-      return h;
+      return H::combine(std::move(h), key.cached_hash_);
     }
+
+    nb::object context() const { return context_; }
+    nb::args args() const { return args_; }
+    nb::kwargs kwargs() const { return kwargs_; }
+
+   private:
+    nb::object context_;
+    nb::args args_;
+    nb::kwargs kwargs_;
+    size_t cached_hash_;
   };
 
   struct CacheEntry {
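The refactored Key above hashes its Python members exactly once, in the constructor, by feeding them through the throw-tolerant HashableKey helper and caching absl::HashOf(...); every later AbslHashValue call only combines the cached size_t, so hashing a Key inside the LRU map never calls back into Python and cannot throw. A minimal standalone sketch of the same pattern, with illustrative names that are not taken from this patch:

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"

// A key that does its (potentially expensive) hashing eagerly, then exposes a
// cheap AbslHashValue that only reuses the cached value.
class CachedHashKey {
 public:
  explicit CachedHashKey(std::vector<std::string> parts)
      : parts_(std::move(parts)), cached_hash_(absl::HashOf(parts_)) {}

  // Equality still inspects the real contents.
  bool operator==(const CachedHashKey& other) const {
    return parts_ == other.parts_;
  }

  template <typename H>
  friend H AbslHashValue(H h, const CachedHashKey& key) {
    return H::combine(std::move(h), key.cached_hash_);  // no re-hash of parts_
  }

 private:
  std::vector<std::string> parts_;
  size_t cached_hash_;
};

// Usage: the map's hasher only ever touches the cached value.
absl::flat_hash_map<CachedHashKey, int> counts;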
@@ -123,14 +148,13 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
   };
 
   struct WeakrefCacheKey {
-    nb::handle object;
+    nb::weakref ref;
     size_t cached_hash;
   };
 
   using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
 
   struct WeakrefCacheValue {
-    std::optional<nb::weakref> weakref;
     std::shared_ptr<Cache> cache;
   };
 
@@ -141,7 +165,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
   struct WeakrefKeyEq {
     bool operator()(const WeakrefCacheKey& lhs,
                     const WeakrefCacheKey& rhs) const {
-      return lhs.object.equal(rhs.object);
+      return lhs.ref.equal(rhs.ref);
     }
   };
 
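With WeakrefCacheKey now carrying the weakref itself plus a precomputed cached_hash, equality on the outer entries_ map still goes through Python (nb::object::equal, i.e. ==), while the hash side presumably just returns the stored value and therefore never re-enters Python or throws. A sketch of the assumed map shape; the hash functor and the exact map declaration are not part of this diff, so their names are guesses:

// Assumed counterpart to WeakrefKeyEq above (not shown in this hunk).
struct WeakrefKeyHash {
  size_t operator()(const WeakrefCacheKey& key) const {
    return key.cached_hash;  // precomputed while the referent was alive
  }
};

// Assumed declaration of the outer map keyed by weakref + cached hash.
std::unordered_map<WeakrefCacheKey, WeakrefCacheValue, WeakrefKeyHash,
                   WeakrefKeyEq>
    entries_;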
@@ -150,43 +174,49 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
       : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}
 
   std::shared_ptr<Cache> GetCache(WeakrefCacheKey key) {
-    auto [it, inserted] = entries_.emplace(key, WeakrefCacheValue());
-    if (!inserted) {
-      return it->second.cache;
+    WeakrefCacheValue& value = entries_[key];
+    if (!value.cache) {
+      value.cache = std::make_shared<Cache>(&lru_list_);
     }
+    return value.cache;
+  }
 
-    auto& value = it->second;
+  nb::object Call(nb::object weakref_key, nb::args args,
+                  nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
+    nb::object context = cache_context_fn_();
+
+    // We precompute all of the hash values needed by the various maps rather
+    // than computing them during the std::unordered_map insertions. At the very
+    // least, MSVC's std::unordered_map has undefined behavior if the hash
+    // function throws an exception
+    // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace).
+    Key key(context, args, kwargs);
+    size_t wrcache_hash = static_cast<size_t>(nb::hash(weakref_key));
+
+    // No hash computations after this point.
 
-    value.cache = std::make_shared<Cache>(&lru_list_);
     auto weakref_gc_callback = nb::cpp_function(
-        [this_weak = weak_from_this(), key](nb::handle weakref) {
+        [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) {
           auto cache = this_weak.lock();
           if (cache == nullptr) {
             return;
           }
-          auto it = cache->entries_.find(key);
+          // The object the reference referred to is now in the process of being
+          // destroyed, so we cannot refer to its contents. Python weakref
+          // objects compare based on identity if the object they refer to is
+          // gone, so the hash lookup will work fine.
+          auto it = cache->entries_.find(
+              WeakrefCacheKey{nb::borrow<nb::weakref>(weakref), wrcache_hash});
           if (it == cache->entries_.end()) {
             return;
           }
           // Create temp-var to avoid re-entrant erase.
           auto tmp = std::move(it->second);
           cache->entries_.erase(it);
         });
-    PyObject* ref =
-        PyWeakref_NewRef(key.object.ptr(), weakref_gc_callback.ptr());
-    if (!ref) {
-      entries_.erase(it);
-      throw nb::python_error();
-    }
-    value.weakref = nb::steal<nb::weakref>(ref);
-    return value.cache;
-  }
-
-  nb::object Call(nb::object weakref_key, nb::args args,
-                  nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
-    nb::object context = cache_context_fn_();
-    std::shared_ptr<Cache> cache_ptr = GetCache(WeakrefCacheKey{
-        weakref_key, static_cast<size_t>(nb::hash(weakref_key))});
+    nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback);
+    WeakrefCacheKey wrcache_key{weakref, wrcache_hash};
+    std::shared_ptr<Cache> cache_ptr = GetCache(wrcache_key);
     Cache& cache = *cache_ptr;
     ++total_queries_;
 
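Call now builds the Key and the weakref-map hash up front, then creates the weakref through nanobind's nb::weakref(object, callback) rather than the raw PyWeakref_NewRef call that GetCache used to make. The GC callback receives only the (now dead) weakref, which Python compares by identity once its referent is gone, so pairing it with the precomputed wrcache_hash is enough to find and erase the entry. A hedged sketch of that weakref-plus-callback idiom in isolation, using a hypothetical helper that is not part of the patch:

#include <nanobind/nanobind.h>

namespace nb = nanobind;

// Hypothetical helper: watch a Python object and run C++ code when it dies.
nb::weakref WatchObject(nb::object obj) {
  auto on_collect = nb::cpp_function([](nb::handle weakref) {
    // The referent is already being destroyed here, so only the weakref
    // itself may be used, e.g. as an identity-compared probe into a map
    // whose hash was precomputed while the referent was still alive.
  });
  // The callback is invoked with the weakref once `obj` is collected.
  return nb::weakref(obj, on_collect);
}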
@@ -206,7 +236,6 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
     // released if that happens.
     absl::Cleanup unlock = [this]()
         ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); };
-    Key key{context, args, kwargs};
     entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) {
       inserted = true;
       return std::make_shared<CacheEntry>();
@@ -245,8 +274,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
     for (const auto& wr_entry : entries_) {
       for (const auto& rest : *wr_entry.second.cache) {
         nb::tuple result =
-            nb::make_tuple(*wr_entry.second.weakref, rest.first.context,
-                           rest.first.args, rest.first.kwargs);
+            nb::make_tuple(*wr_entry.first.ref, rest.first.context(),
+                           rest.first.args(), rest.first.kwargs());
         results.push_back(std::move(result));
       }
     }