1919import weakref
2020
2121from absl .testing import absltest
22-
23- from jax .jaxlib .xla import xla_client
22+ from jax .jaxlib import weakref_lru_cache
2423
2524
2625class WeakrefLRUCacheTest (absltest .TestCase ):
@@ -60,7 +59,7 @@ def CacheFn(obj, gil_releasing_cache_key):
6059 del gil_releasing_cache_key
6160 return None
6261
63- cache = xla_client .weakref_lru_cache (lambda : None , CacheFn , 2048 )
62+ cache = weakref_lru_cache .weakref_lru_cache (lambda : None , CacheFn , 2048 )
6463
6564 wrkey = WRKey ()
6665
@@ -79,7 +78,9 @@ def Body():
7978 def testAnotherMultiThreaded (self ):
8079 num_workers = 5
8180 barrier = threading .Barrier (num_workers )
82- cache = xla_client .weakref_lru_cache (lambda : None , lambda x , y : y , 2048 )
81+ cache = weakref_lru_cache .weakref_lru_cache (
82+ lambda : None , lambda x , y : y , 2048
83+ )
8384
8485 class WRKey :
8586 pass
@@ -118,7 +119,7 @@ def CacheFn(obj, kwkey1, kwkey2):
118119 miss_id += 1
119120 return miss_id
120121
121- cache = xla_client .weakref_lru_cache (lambda : None , CacheFn , 4 )
122+ cache = weakref_lru_cache .weakref_lru_cache (lambda : None , CacheFn , 4 )
122123
123124 wrkey = WRKey ()
124125
@@ -131,7 +132,7 @@ def CacheFn(obj, arg):
131132 del obj
132133 return arg + "extra"
133134
134- cache = xla_client .weakref_lru_cache (lambda : None , CacheFn , 4 )
135+ cache = weakref_lru_cache .weakref_lru_cache (lambda : None , CacheFn , 4 )
135136
136137 class WRKey :
137138 pass
@@ -151,7 +152,7 @@ class NonWRKey:
151152 with self .assertRaises (TypeError ):
152153 weakref .ref (non_wr_key )
153154
154- cache = xla_client .weakref_lru_cache (lambda : None , lambda x : 2048 )
155+ cache = weakref_lru_cache .weakref_lru_cache (lambda : None , lambda x : 2048 )
155156 for _ in range (100 ):
156157 with self .assertRaises (TypeError ):
157158 cache (non_wr_key )
@@ -169,7 +170,9 @@ def __eq__(self, other):
169170 def __hash__ (self ):
170171 raise ValueError ("hash" )
171172
172- cache = xla_client .weakref_lru_cache (lambda : None , lambda x , y : y , 2048 )
173+ cache = weakref_lru_cache .weakref_lru_cache (
174+ lambda : None , lambda x , y : y , 2048
175+ )
173176 wrkey = WRKey ()
174177 with self .assertRaises (ValueError ):
175178 for _ in range (100 ):
@@ -179,7 +182,9 @@ def testPrintingStats(self):
179182 class WRKey :
180183 pass
181184
182- cache = xla_client .weakref_lru_cache (lambda : None , lambda x , y : y , 2048 )
185+ cache = weakref_lru_cache .weakref_lru_cache (
186+ lambda : None , lambda x , y : y , 2048
187+ )
183188 wrkey = WRKey ()
184189 for i in range (10 ):
185190 cache (wrkey , i )
@@ -203,7 +208,9 @@ def __eq__(self, other):
203208 def __hash__ (self ):
204209 return hash (self .x )
205210
206- cache = xla_client .weakref_lru_cache (lambda : None , lambda x , y : y , 2048 )
211+ cache = weakref_lru_cache .weakref_lru_cache (
212+ lambda : None , lambda x , y : y , 2048
213+ )
207214 keys = [WRKey (i ) for i in range (10 )]
208215 for i in range (10 ):
209216 cache (keys [i ], i )
@@ -225,7 +232,7 @@ def CallFn(x, y, *args, **kwargs):
225232 del x , args , kwargs
226233 return y
227234
228- cache = xla_client .weakref_lru_cache (CacheContextFn , CallFn , 2048 )
235+ cache = weakref_lru_cache .weakref_lru_cache (CacheContextFn , CallFn , 2048 )
229236
230237 keys = [WRKey () for _ in range (10 )]
231238 values = [str (i ) for i in range (10 )]
@@ -239,7 +246,7 @@ def CallFn(x, y, *args, **kwargs):
239246 [
240247 CacheContextFn ,
241248 CallFn ,
242- xla_client . _xla .WeakrefLRUCache ,
249+ weakref_lru_cache .WeakrefLRUCache ,
243250 kwargs ,
244251 ]
245252 + [weakref .getweakrefs (key )[0 ] for key in keys ]
0 commit comments