Skip to content

Commit 472e5a2

Browse files
TST: Speed up repeated jax tests (#649)
Add pytest option --jax-pcc for the jax persistent compilation cache and add notes to developer contributing guide This uses a jax config to set cache options. Total size of cache for all tests (well, really four main ones) is only 2.7M on linux, 416 files. Reduces runtime from 4.16s call test_optimizers.py::test_fit[SBR] 2.98s call test_optimizers.py::test_complexity_not_fitted[SBR] 2.89s call test_optimizers.py::test_pickle[SBR-opt_args1] 2.89s call test_optimizers.py::test_sbr_accurate To: 1.94s call test_optimizers.py::test_fit[SBR] 1.56s call test_optimizers.py::test_sbr_accurate 1.23s call test_optimizers.py::test_pickle[SBR-opt_args1] 1.18s call test_optimizers.py::test_complexity_not_fitted[SBR]
1 parent b866bc7 commit 472e5a2

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

docs/contributing.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,16 @@ Install our pre-commit script via
141141
142142
pre-commit install
143143
144+
Pre-commit will automatically check all future commits for code style.
144145
To be accepted your code should conform to PEP8 and pass all unit tests.
145146
Code can be tested by invoking
146147

147148
.. code-block:: bash
148149
149-
pytest
150-
151-
Pre-commit will automatically check all future commits for code style.
150+
pytest --jax-pcc
152151
152+
The ``jax-pcc`` flag is optional and will speed up the jax tests on repeated runs
153+
by caching compiled functions to disk.
153154

154155

155156
Coding Guidelines

test/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Shared pytest fixtures for unit tests.
33
"""
44
from pathlib import Path
5+
from tempfile import gettempdir
56

67
import jax
78
import numpy as np
@@ -34,6 +35,14 @@ def pytest_addoption(parser):
3435
" test_notebooks.test_external"
3536
),
3637
)
38+
parser.addoption(
39+
"--jax-pcc",
40+
action="store_true",
41+
help=(
42+
"Whether to cache @jax.jit compilations to disk."
43+
"It speeds up the slowest tests appx 2x, writing files to /tmp/jax_cache"
44+
),
45+
)
3746

3847

3948
def pytest_generate_tests(metafunc):
@@ -47,6 +56,17 @@ def pytest_generate_tests(metafunc):
4756
)
4857

4958

59+
@pytest.fixture(scope="session")
60+
def set_jax_pcc_env(request):
61+
if request.config.getoption("--jax-pcc"):
62+
cache_dir = Path(gettempdir()) / "jax_cache/"
63+
cache_dir.mkdir(exist_ok=True)
64+
jax.config.update("jax_compilation_cache_dir", str(cache_dir))
65+
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
66+
jax.config.update("jax_persistent_cache_min_compile_time_secs", -1)
67+
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")
68+
69+
5070
@pytest.fixture(scope="session")
5171
def data_1d():
5272
t = np.linspace(0, 1, 10)

test/test_optimizers/test_optimizers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def data(request):
130130
],
131131
ids=lambda param: type(param),
132132
)
133-
def test_fit(data_derivative_1d, optimizer):
133+
def test_fit(data_derivative_1d, optimizer, set_jax_pcc_env):
134134
x, x_dot = data_derivative_1d
135135
if len(x.shape) == 1:
136136
x = x.reshape(-1, 1)
@@ -168,7 +168,7 @@ def test_not_fitted(optimizer):
168168
],
169169
ids=type,
170170
)
171-
def test_complexity_not_fitted(optimizer, data_derivative_2d):
171+
def test_complexity_not_fitted(optimizer, data_derivative_2d, set_jax_pcc_env):
172172
with pytest.raises(NotFittedError):
173173
optimizer.complexity
174174

@@ -389,7 +389,7 @@ def test_sbr_bad_parameters(params):
389389
SBR(**params)
390390

391391

392-
def test_sbr_accurate():
392+
def test_sbr_accurate(set_jax_pcc_env):
393393
# It's really hard to tune SBR to get desired shrinkage
394394
# This just tests that SBR fits "close" to unregularized regression
395395
x = np.tile(np.eye(2), 4).reshape((-1, 2))
@@ -1193,7 +1193,7 @@ def test_remove_and_decrement():
11931193
(TrappingSR3, {"_n_tgts": 3, "_include_bias": True}),
11941194
),
11951195
)
1196-
def test_pickle(data_lorenz, opt_cls, opt_args):
1196+
def test_pickle(data_lorenz, opt_cls, opt_args, set_jax_pcc_env):
11971197
x, t = data_lorenz
11981198
y = PolynomialLibrary(degree=2).fit_transform(x)
11991199
opt = opt_cls(**opt_args).fit(y, x)

0 commit comments

Comments
 (0)