Skip to content

Commit 1cdfa4e

Browse files
committed
Remove chex dependency
1 parent babce88 commit 1cdfa4e

File tree

3 files changed

+29
-22
lines changed

3 files changed

+29
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@ authors = [
1111
{name = "Flax team", email = "flax-dev@google.com"},
1212
]
1313
dependencies = [
14-
# temporary numpy version fix due to
15-
# https://github.com/google/flax/issues/5162
16-
# and
17-
# https://github.com/google-deepmind/chex/issues/424
18-
"numpy>=1.23.2,<2.4.0; python_version>='3.11'",
19-
"numpy>=1.26.0,<2.4.0; python_version>='3.12'",
14+
"numpy>=1.23.2",
2015
# keep in sync with jax-version in .github/workflows/build.yml
2116
"jax>=0.8.1",
2217
"msgpack",
@@ -43,7 +38,6 @@ dynamic = ["version", "readme"]
4338
[project.optional-dependencies]
4439
testing = [
4540
"clu",
46-
"chex",
4741
"clu<=0.0.9; python_version<'3.10'",
4842
"einops",
4943
"gymnasium[atari]; python_version<'3.14'",

tests/jax_utils_test.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from absl.testing import absltest
2222
from absl.testing import parameterized
23-
import chex
2423
from flax import jax_utils
2524
import jax
2625
import jax.numpy as jnp
@@ -29,13 +28,30 @@
2928
NDEV = 4
3029

3130

32-
class PadShardUnpadTest(chex.TestCase):
31+
def assert_max_traces(n):
32+
"""Decorator to assert that a function is traced at most n times."""
33+
def decorator(fn):
34+
trace_count = {'count': 0}
35+
36+
def wrapped(*args, **kwargs):
37+
trace_count['count'] += 1
38+
if trace_count['count'] > n:
39+
raise AssertionError(
40+
f"Function was traced {trace_count['count']} times, "
41+
f"expected at most {n} traces"
42+
)
43+
return fn(*args, **kwargs)
44+
45+
wrapped.trace_count = trace_count
46+
return wrapped
47+
return decorator
48+
49+
50+
class PadShardUnpadTest(parameterized.TestCase):
3351
BATCH_SIZES = [NDEV, NDEV + 1, NDEV - 1, 5 * NDEV, 5 * NDEV + 1, 5 * NDEV - 1]
3452
DTYPES = [np.float32, np.uint8, jax.numpy.bfloat16, np.int32]
3553

36-
def tearDown(self):
37-
chex.clear_trace_counter()
38-
super().tearDown()
54+
3955

4056
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
4157
def test_basics(self, dtype, bs):
@@ -47,7 +63,7 @@ def add(a, b):
4763

4864
x = np.arange(bs, dtype=dtype)
4965
y = add(x, 10 * x)
50-
chex.assert_type(y.dtype, x.dtype)
66+
self.assertEqual(y.dtype, x.dtype)
5167
np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x))
5268

5369
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
@@ -59,24 +75,22 @@ def add(a, b):
5975

6076
x = jnp.arange(bs, dtype=dtype)
6177
y = add(dict(a=x), (10 * x,))
62-
chex.assert_type(y.dtype, x.dtype)
78+
self.assertEqual(y.dtype, x.dtype)
6379
np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x))
6480

6581
@parameterized.parameters(DTYPES)
6682
def test_min_device_batch_avoids_recompile(self, dtype):
6783
@partial(jax_utils.pad_shard_unpad, static_argnums=())
6884
@jax.jit
69-
@chex.assert_max_traces(n=1)
85+
@assert_max_traces(n=1)
7086
def add(a, b):
7187
b = jnp.asarray(b, dtype=dtype)
7288
return a + b
7389

74-
chex.clear_trace_counter()
75-
7690
for bs in self.BATCH_SIZES:
7791
x = jnp.arange(bs, dtype=dtype)
7892
y = add(x, 10 * x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg
79-
chex.assert_type(y.dtype, x.dtype)
93+
self.assertEqual(y.dtype, x.dtype)
8094
np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x))
8195

8296
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
@@ -87,7 +101,7 @@ def add(a, b):
87101

88102
x = jnp.arange(bs, dtype=dtype)
89103
y = add(x, 10)
90-
chex.assert_type(y.dtype, x.dtype)
104+
self.assertEqual(y.dtype, x.dtype)
91105
np.testing.assert_allclose(np.float64(y), np.float64(x + 10))
92106

93107
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
@@ -102,7 +116,7 @@ def add(params, a, *, b):
102116

103117
x = jnp.arange(bs, dtype=dtype)
104118
y = add(5, x, b=10)
105-
chex.assert_type(y.dtype, x.dtype)
119+
self.assertEqual(y.dtype, x.dtype)
106120
np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10))
107121

108122

tests/linen/linen_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Any
2020

2121
from absl.testing import absltest, parameterized
22-
import chex
2322
from flax import ids
2423
from flax import linen as nn
2524
from flax.linen import fp8_ops
@@ -960,7 +959,7 @@ def _strip_partitioning(x):
960959
np.testing.assert_array_equal(got_scale, expected_scale)
961960

962961
# Compares the rest of PyTree nodes
963-
chex.assert_trees_all_close(expected, got)
962+
jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y), expected, got)
964963

965964

966965
class StochasticTest(parameterized.TestCase):

0 commit comments

Comments
 (0)