Skip to content

Commit bbfab0a

Browse files
author
Flax Authors
committed
Merge pull request #5295 from samanklesaria:remove_chex
PiperOrigin-RevId: 878167338
2 parents 357b298 + afc169e commit bbfab0a

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@ authors = [
1111
{name = "Flax team", email = "flax-dev@google.com"},
1212
]
1313
dependencies = [
14-
# temporary numpy version fix due to
15-
# https://github.com/google/flax/issues/5162
16-
# and
17-
# https://github.com/google-deepmind/chex/issues/424
18-
"numpy>=1.23.2,<2.4.0; python_version>='3.11'",
19-
"numpy>=1.26.0,<2.4.0; python_version>='3.12'",
14+
"numpy>=1.23.2",
2015
# keep in sync with jax-version in .github/workflows/build.yml
2116
"jax>=0.8.1",
2217
"msgpack",
@@ -43,7 +38,6 @@ dynamic = ["version", "readme"]
4338
[project.optional-dependencies]
4439
testing = [
4540
"clu",
46-
"chex",
4741
"clu<=0.0.9; python_version<'3.10'",
4842
"einops",
4943
"gymnasium[atari]; python_version<'3.14'",

tests/jax_utils_test.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from absl.testing import absltest
2222
from absl.testing import parameterized
23-
import chex
2423
from flax import jax_utils
2524
import jax
2625
import jax.numpy as jnp
@@ -29,13 +28,32 @@
2928
NDEV = 4
3029

3130

32-
class PadShardUnpadTest(chex.TestCase):
31+
def assert_max_traces(n):
32+
"""Decorator to assert that a function is traced at most n times."""
33+
from functools import wraps
34+
def decorator(fn):
35+
trace_count = {'count': 0}
36+
37+
@wraps(fn)
38+
def wrapped(*args, **kwargs):
39+
trace_count['count'] += 1
40+
if trace_count['count'] > n:
41+
raise AssertionError(
42+
f"Function was traced {trace_count['count']} times, "
43+
f"expected at most {n} traces"
44+
)
45+
return fn(*args, **kwargs)
46+
47+
wrapped.trace_count = trace_count
48+
return wrapped
49+
return decorator
50+
51+
52+
class PadShardUnpadTest(parameterized.TestCase):
3353
BATCH_SIZES = [NDEV, NDEV + 1, NDEV - 1, 5 * NDEV, 5 * NDEV + 1, 5 * NDEV - 1]
3454
DTYPES = [np.float32, np.uint8, jax.numpy.bfloat16, np.int32]
3555

36-
def tearDown(self):
37-
chex.clear_trace_counter()
38-
super().tearDown()
56+
3957

4058
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
4159
def test_basics(self, dtype, bs):
@@ -47,7 +65,7 @@ def add(a, b):
4765

4866
x = np.arange(bs, dtype=dtype)
4967
y = add(x, 10 * x)
50-
chex.assert_type(y.dtype, x.dtype)
68+
self.assertEqual(y.dtype, x.dtype)
5169
np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x))
5270

5371
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
@@ -59,24 +77,22 @@ def add(a, b):
5977

6078
x = jnp.arange(bs, dtype=dtype)
6179
y = add(dict(a=x), (10 * x,))
62-
chex.assert_type(y.dtype, x.dtype)
80+
self.assertEqual(y.dtype, x.dtype)
6381
np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x))
6482

6583
@parameterized.parameters(DTYPES)
6684
def test_min_device_batch_avoids_recompile(self, dtype):
6785
@partial(jax_utils.pad_shard_unpad, static_argnums=())
6886
@jax.jit
69-
@chex.assert_max_traces(n=1)
87+
@assert_max_traces(n=1)
7088
def add(a, b):
7189
b = jnp.asarray(b, dtype=dtype)
7290
return a + b
7391

74-
chex.clear_trace_counter()
75-
7692
for bs in self.BATCH_SIZES:
7793
x = jnp.arange(bs, dtype=dtype)
7894
y = add(x, 10 * x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg
79-
chex.assert_type(y.dtype, x.dtype)
95+
self.assertEqual(y.dtype, x.dtype)
8096
np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x))
8197

8298
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
@@ -87,7 +103,7 @@ def add(a, b):
87103

88104
x = jnp.arange(bs, dtype=dtype)
89105
y = add(x, 10)
90-
chex.assert_type(y.dtype, x.dtype)
106+
self.assertEqual(y.dtype, x.dtype)
91107
np.testing.assert_allclose(np.float64(y), np.float64(x + 10))
92108

93109
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
@@ -102,7 +118,7 @@ def add(params, a, *, b):
102118

103119
x = jnp.arange(bs, dtype=dtype)
104120
y = add(5, x, b=10)
105-
chex.assert_type(y.dtype, x.dtype)
121+
self.assertEqual(y.dtype, x.dtype)
106122
np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10))
107123

108124

tests/linen/linen_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Any
2020

2121
from absl.testing import absltest, parameterized
22-
import chex
2322
from flax import ids
2423
from flax import linen as nn
2524
from flax.linen import fp8_ops
@@ -960,7 +959,7 @@ def _strip_partitioning(x):
960959
np.testing.assert_array_equal(got_scale, expected_scale)
961960

962961
# Compares the rest of PyTree nodes
963-
chex.assert_trees_all_close(expected, got)
962+
jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y), expected, got)
964963

965964

966965
class StochasticTest(parameterized.TestCase):

0 commit comments

Comments
 (0)