Skip to content

Commit b4629c2

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Split weakref_lru_cache into its own extension.
Now that jax-ml@db11efa has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA. There's no reason weakref_lru_cache is in the same Python extension as everything else. PiperOrigin-RevId: 745271825
1 parent 2d44f98 commit b4629c2

File tree

13 files changed

+112
-102
lines changed

13 files changed

+112
-102
lines changed

jax/_src/lib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ py_library_providing_imports_info(
4444
"//jaxlib/mosaic/python:tpu_dialect",
4545
"//jaxlib:cpu_feature_guard",
4646
"//jaxlib:utils",
47+
"//jaxlib:weakref_lru_cache",
4748
"//jaxlib/xla:xla_client",
4849
"//jaxlib/xla:xla_extension",
4950
"//jaxlib/triton",

jax/_src/lib/__init__.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,23 @@ def _parse_version(v: str) -> tuple[int, ...]:
9090
from jaxlib.xla_extension import jax_jit as jax_jit # noqa: F401
9191
from jaxlib.xla_extension import pmap_lib as pmap_lib # noqa: F401
9292
from jaxlib.xla_extension import pytree as pytree # noqa: F401
93-
import jaxlib.xla_client as xla_client # noqa: F401
9493

9594
from jaxlib.xla_extension import Device as Device # noqa: F401
9695

96+
import jaxlib.xla_client as xla_client # noqa: F401
97+
98+
# Jaxlib code is split between the Jax and the XLA repositories.
99+
# Only for the internal usage of the JAX developers, we expose a version
100+
# number that can be used to perform changes without breaking the main
101+
# branch on the Jax github.
102+
jaxlib_extension_version: int = getattr(xla_client, '_version', 0)
103+
ifrt_version: int = getattr(xla_client, '_ifrt_version', 0)
104+
105+
# TODO(phawkins): remove type: ignore once the minimum jaxlib is bumped.
106+
if jaxlib_extension_version >= 328:
107+
import jaxlib.weakref_lru_cache as weakref_lru_cache # type: ignore # noqa: F401
108+
else:
109+
weakref_lru_cache = xla_extension # type: ignore # noqa: F401
97110

98111
# XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882
99112
def _xla_gc_callback(*args):
@@ -113,13 +126,6 @@ def _xla_gc_callback(*args):
113126
import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401
114127
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401
115128

116-
# Jaxlib code is split between the Jax and the Tensorflow repositories.
117-
# Only for the internal usage of the JAX developers, we expose a version
118-
# number that can be used to perform changes without breaking the main
119-
# branch on the Jax github.
120-
jaxlib_extension_version: int = getattr(xla_client, '_version', 0)
121-
ifrt_version: int = getattr(xla_client, '_ifrt_version', 0)
122-
123129
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401
124130
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401
125131

jax/_src/util.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import numpy as np
2828

2929
from jax._src import config
30-
from jax._src.lib import xla_client as xc
30+
from jax._src.lib import weakref_lru_cache as _weakref_lru_cache
3131
from jax._src.lib import utils as jaxlib_utils
3232

3333
logger = logging.getLogger(__name__)
@@ -304,8 +304,9 @@ def weakref_lru_cache(call: Callable, maxsize=2048,
304304
behave similar to `functools.lru_cache`.
305305
"""
306306
global _weakref_lru_caches
307-
cached_call = xc.weakref_lru_cache(
308-
config.trace_context if trace_context_in_key else _ignore, call, maxsize)
307+
cached_call = _weakref_lru_cache.weakref_lru_cache(
308+
config.trace_context if trace_context_in_key else _ignore, call, maxsize
309+
)
309310
_weakref_lru_caches.add(cached_call)
310311
return cached_call
311312

jaxlib/BUILD

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
load(
1818
"//jaxlib:jax.bzl",
1919
"nanobind_extension",
20+
"py_deps",
2021
"py_library_providing_imports_info",
22+
"py_strict_test",
2123
"pytype_library",
2224
)
2325
load(
@@ -59,6 +61,7 @@ py_library_providing_imports_info(
5961
":cpu_feature_guard",
6062
":jax",
6163
":utils",
64+
":weakref_lru_cache",
6265
"//jaxlib/cpu:_lapack",
6366
"//jaxlib/mlir",
6467
"//jaxlib/mlir:arithmetic_dialect",
@@ -119,6 +122,7 @@ pywrap_library(
119122
},
120123
deps = [
121124
":utils",
125+
":weakref_lru_cache",
122126
"//jaxlib/mlir/_mlir_libs:_chlo",
123127
"//jaxlib/mlir/_mlir_libs:_mlir",
124128
"//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU",
@@ -215,6 +219,35 @@ nanobind_extension(
215219
],
216220
)
217221

222+
nanobind_pywrap_extension(
223+
name = "weakref_lru_cache",
224+
srcs = ["weakref_lru_cache.cc"],
225+
pytype_srcs = ["weakref_lru_cache.pyi"],
226+
deps = [
227+
"@com_google_absl//absl/base:core_headers",
228+
"@com_google_absl//absl/cleanup",
229+
"@com_google_absl//absl/hash",
230+
"@com_google_absl//absl/strings",
231+
"@com_google_absl//absl/synchronization",
232+
"@nanobind",
233+
"@xla//third_party/python_runtime:headers",
234+
"@xla//xla/pjrt:lru_cache",
235+
"@xla//xla/tsl/platform:logging",
236+
],
237+
)
238+
239+
py_strict_test(
240+
name = "weakref_lru_cache_test",
241+
srcs = ["weakref_lru_cache_test.py"],
242+
deps = [
243+
":weakref_lru_cache",
244+
] + py_deps([
245+
"absl/flags",
246+
"absl/logging",
247+
"absl/testing",
248+
]),
249+
)
250+
218251
nanobind_pywrap_extension(
219252
name = "utils",
220253
srcs = ["utils.cc"],

jaxlib/tools/build_wheel.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def patch_copy_mlir_import(src_file, dst_dir):
103103
"pytree.pyi",
104104
"transfer_guard_lib.pyi",
105105
]
106-
_OPTIONAL_XLA_EXTENSION_STUBS = []
107106

108107

109108
def patch_copy_xla_extension_stubs(dst_dir):
@@ -112,8 +111,6 @@ def patch_copy_xla_extension_stubs(dst_dir):
112111
for stub_name in _XLA_EXTENSION_STUBS:
113112
stub_path = r.Rlocation("__main__/jaxlib/xla/xla_extension/" + stub_name)
114113
stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path).
115-
if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
116-
continue
117114
with open(stub_path) as f:
118115
src = f.read()
119116
src = src.replace(
@@ -199,6 +196,8 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu):
199196
"__main__/jaxlib/plugin_support.py",
200197
"__main__/jaxlib/version.py",
201198
"__main__/jaxlib/xla/xla_client.py",
199+
f"__main__/jaxlib/weakref_lru_cache.{pyext}",
200+
"__main__/jaxlib/weakref_lru_cache.pyi",
202201
f"__main__/jaxlib/xla_extension.{pyext}",
203202
],
204203
)
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "jaxlib/xla/weakref_lru_cache.h"
17-
1816
#include <Python.h>
1917

2018
#include <cstddef>
@@ -369,7 +367,7 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
369367
{0, nullptr},
370368
};
371369

372-
void BuildWeakrefLRUCacheAPI(nb::module_& m) {
370+
NB_MODULE(weakref_lru_cache, m) {
373371
auto weakref_lru_cache =
374372
nb::class_<WeakrefLRUCache>(m, "WeakrefLRUCache",
375373
nb::is_weak_referenceable(),

jaxlib/weakref_lru_cache.pyi

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 The JAX Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from collections.abc import Callable
17+
18+
class WeakrefLRUCache:
19+
def __call__(self, arg0: object, /, *args, **kwargs) -> object: ...
20+
def cache_keys(self) -> list[object]: ...
21+
def cache_info(self) -> WeakrefLRUCache.WeakrefLRUCacheInfo: ...
22+
def cache_clear(self) -> None: ...
23+
24+
class WeakrefLRUCacheInfo:
25+
@property
26+
def hits(self) -> int: ...
27+
@property
28+
def misses(self) -> int: ...
29+
@property
30+
def maxsize(self) -> int: ...
31+
@property
32+
def currsize(self) -> int: ...
33+
def __repr__(self) -> str: ...
34+
35+
def weakref_lru_cache(
36+
cache_context_fn: Callable, fn: Callable, maxsize: int = 2048
37+
) -> WeakrefLRUCache: ...
Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
import weakref
2020

2121
from absl.testing import absltest
22-
23-
from jax.jaxlib.xla import xla_client
22+
from jax.jaxlib import weakref_lru_cache
2423

2524

2625
class WeakrefLRUCacheTest(absltest.TestCase):
@@ -60,7 +59,7 @@ def CacheFn(obj, gil_releasing_cache_key):
6059
del gil_releasing_cache_key
6160
return None
6261

63-
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 2048)
62+
cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 2048)
6463

6564
wrkey = WRKey()
6665

@@ -79,7 +78,9 @@ def Body():
7978
def testAnotherMultiThreaded(self):
8079
num_workers = 5
8180
barrier = threading.Barrier(num_workers)
82-
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
81+
cache = weakref_lru_cache.weakref_lru_cache(
82+
lambda: None, lambda x, y: y, 2048
83+
)
8384

8485
class WRKey:
8586
pass
@@ -118,7 +119,7 @@ def CacheFn(obj, kwkey1, kwkey2):
118119
miss_id += 1
119120
return miss_id
120121

121-
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4)
122+
cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 4)
122123

123124
wrkey = WRKey()
124125

@@ -131,7 +132,7 @@ def CacheFn(obj, arg):
131132
del obj
132133
return arg + "extra"
133134

134-
cache = xla_client.weakref_lru_cache(lambda: None, CacheFn, 4)
135+
cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 4)
135136

136137
class WRKey:
137138
pass
@@ -151,7 +152,7 @@ class NonWRKey:
151152
with self.assertRaises(TypeError):
152153
weakref.ref(non_wr_key)
153154

154-
cache = xla_client.weakref_lru_cache(lambda: None, lambda x: 2048)
155+
cache = weakref_lru_cache.weakref_lru_cache(lambda: None, lambda x: 2048)
155156
for _ in range(100):
156157
with self.assertRaises(TypeError):
157158
cache(non_wr_key)
@@ -169,7 +170,9 @@ def __eq__(self, other):
169170
def __hash__(self):
170171
raise ValueError("hash")
171172

172-
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
173+
cache = weakref_lru_cache.weakref_lru_cache(
174+
lambda: None, lambda x, y: y, 2048
175+
)
173176
wrkey = WRKey()
174177
with self.assertRaises(ValueError):
175178
for _ in range(100):
@@ -179,7 +182,9 @@ def testPrintingStats(self):
179182
class WRKey:
180183
pass
181184

182-
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
185+
cache = weakref_lru_cache.weakref_lru_cache(
186+
lambda: None, lambda x, y: y, 2048
187+
)
183188
wrkey = WRKey()
184189
for i in range(10):
185190
cache(wrkey, i)
@@ -203,7 +208,9 @@ def __eq__(self, other):
203208
def __hash__(self):
204209
return hash(self.x)
205210

206-
cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048)
211+
cache = weakref_lru_cache.weakref_lru_cache(
212+
lambda: None, lambda x, y: y, 2048
213+
)
207214
keys = [WRKey(i) for i in range(10)]
208215
for i in range(10):
209216
cache(keys[i], i)
@@ -225,7 +232,7 @@ def CallFn(x, y, *args, **kwargs):
225232
del x, args, kwargs
226233
return y
227234

228-
cache = xla_client.weakref_lru_cache(CacheContextFn, CallFn, 2048)
235+
cache = weakref_lru_cache.weakref_lru_cache(CacheContextFn, CallFn, 2048)
229236

230237
keys = [WRKey() for _ in range(10)]
231238
values = [str(i) for i in range(10)]
@@ -239,7 +246,7 @@ def CallFn(x, y, *args, **kwargs):
239246
[
240247
CacheContextFn,
241248
CallFn,
242-
xla_client._xla.WeakrefLRUCache,
249+
weakref_lru_cache.WeakrefLRUCache,
243250
kwargs,
244251
]
245252
+ [weakref.getweakrefs(key)[0] for key in keys]

jaxlib/xla/BUILD

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ nanobind_pywrap_extension(
6363
":sdy",
6464
":traceback",
6565
":util",
66-
":weakref_lru_cache",
6766
":xla_compiler",
6867
"@com_google_absl//absl/base",
6968
"@com_google_absl//absl/container:flat_hash_map",
@@ -876,29 +875,6 @@ cc_library(
876875
],
877876
)
878877

879-
cc_library(
880-
name = "weakref_lru_cache",
881-
srcs = ["weakref_lru_cache.cc"],
882-
hdrs = ["weakref_lru_cache.h"],
883-
compatible_with = [],
884-
copts = [
885-
"-fexceptions",
886-
"-fno-strict-aliasing",
887-
],
888-
features = ["-use_header_modules"],
889-
deps = [
890-
"@com_google_absl//absl/base:core_headers",
891-
"@com_google_absl//absl/cleanup",
892-
"@com_google_absl//absl/hash",
893-
"@com_google_absl//absl/strings",
894-
"@com_google_absl//absl/synchronization",
895-
"@nanobind",
896-
"@xla//third_party/python_runtime:headers",
897-
"@xla//xla/pjrt:lru_cache",
898-
"@xla//xla/tsl/platform:logging",
899-
],
900-
)
901-
902878
cc_library(
903879
name = "xla_compiler",
904880
srcs = ["xla_compiler.cc"],
@@ -1041,18 +1017,6 @@ py_strict_test(
10411017
]),
10421018
)
10431019

1044-
py_strict_test(
1045-
name = "weakref_lru_cache_test",
1046-
srcs = ["weakref_lru_cache_test.py"],
1047-
deps = [
1048-
":xla_client",
1049-
] + py_deps([
1050-
"absl/flags",
1051-
"absl/logging",
1052-
"absl/testing",
1053-
]),
1054-
)
1055-
10561020
py_strict_test(
10571021
name = "pytree_test",
10581022
srcs = ["pytree_test.py"],

0 commit comments

Comments
 (0)