Skip to content

Commit 4d175e9

Browse files
committed
Expand Dropout test coverage and migrate to absltest
1 parent 9a1bc4e commit 4d175e9

File tree

1 file changed

+246
-33
lines changed

1 file changed

+246
-33
lines changed

tests/nnx/nn/stochastic_test.py

Lines changed: 246 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import itertools
1516

17+
import jax
1618
import jax.numpy as jnp
19+
from absl.testing import absltest
20+
from absl.testing import parameterized
21+
from jax import random
1722
import numpy as np
1823

1924
from flax import nnx
2025

21-
import pytest
2226

23-
24-
class TestStochastic:
27+
class TestDropout(parameterized.TestCase):
2528
def test_dropout_internal_rngs(self):
2629
n = 0
27-
m1 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0))
30+
m1 = nnx.Dropout(
31+
rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0)
32+
)
2833
m2 = nnx.Dropout(rate=0.5, deterministic=False)
2934
rngs2 = nnx.Rngs(dropout=0).fork()
3035

@@ -35,34 +40,42 @@ def f(m, x, rngs=None):
3540
return m(x, rngs=rngs)
3641

3742
x = jnp.ones((1, 10))
38-
assert m1.rngs is not None and m1.rngs.count[...] == 0
43+
self.assertIsNotNone(m1.rngs)
44+
self.assertEqual(m1.rngs.count[...], 0)
3945

4046
y1 = f(m1, x)
41-
assert n == 1
42-
assert m1.rngs.count[...] == 1
47+
self.assertEqual(n, 1)
48+
self.assertEqual(m1.rngs.count[...], 1)
4349
y2 = f(m2, x, rngs=rngs2)
44-
assert n == 2
45-
assert rngs2.dropout.count[...] == 1
50+
self.assertEqual(n, 2)
51+
self.assertEqual(rngs2.dropout.count[...], 1)
4652
np.testing.assert_allclose(y1, y2)
4753

4854
y1 = f(m1, x)
49-
assert m1.rngs.count[...] == 2
55+
self.assertEqual(m1.rngs.count[...], 2)
5056
y2 = f(m2, x, rngs=rngs2)
51-
assert rngs2.dropout.count[...] == 2
57+
self.assertEqual(rngs2.dropout.count[...], 2)
5258
np.testing.assert_allclose(y1, y2)
5359

54-
assert n == 2
60+
self.assertEqual(n, 2)
5561

5662
def test_dropout_rng_override(self):
57-
m1 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0))
58-
m2 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=1))
59-
x = jnp.ones((1, 10))
63+
m1 = nnx.Dropout(
64+
rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0)
65+
)
66+
m2 = nnx.Dropout(
67+
rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=1)
68+
)
69+
x = jnp.ones((10, 10))
6070

6171
y1 = m1(x)
6272
y2 = m2(x)
63-
with pytest.raises(AssertionError):
64-
np.testing.assert_allclose(y1, y2)
73+
self.assertFalse(
74+
np.array_equal(y1, y2),
75+
'Different RNG seeds should produce different masks',
76+
)
6577

78+
# Override m2's seed with m1's seed -- outputs should match
6679
y2 = m2(x, rngs=nnx.Rngs(dropout=0).fork())
6780
np.testing.assert_allclose(y1, y2)
6881

@@ -71,41 +84,241 @@ def test_dropout_arg_override(self):
7184
x = jnp.ones((1, 10))
7285

7386
# deterministic call arg provided
74-
m(x, deterministic=True)
87+
y_det = m(x, deterministic=True)
88+
np.testing.assert_array_equal(y_det, x)
7589
# deterministic attribute set on the module (via set_attributes)
7690
m.set_attributes(deterministic=True)
7791
y = m(x)
7892
# deterministic call arg overrides the deterministic attribute set on the module
79-
with pytest.raises(AssertionError):
80-
np.testing.assert_allclose(
93+
self.assertFalse(
94+
np.array_equal(
8195
y, m(x, deterministic=False, rngs=nnx.Rngs(dropout=0))
82-
)
96+
),
97+
'deterministic output should differ from stochastic output',
98+
)
8399
# no rng arg provided
84100
m.set_attributes(deterministic=False)
85-
with pytest.raises(
101+
with self.assertRaisesRegex(
86102
ValueError,
87-
match='`deterministic` is False, but no `rngs` argument was provided to Dropout',
103+
r'`deterministic` is False.*no `rngs` argument',
88104
):
89105
m(x)
90106

91107
def test_dropout_arg_override_view(self):
92108
m = nnx.Dropout(rate=0.5)
93109
x = jnp.ones((1, 10))
94110

95-
# deterministic call arg provided
96-
m(x, deterministic=True)
97-
# deterministic constructor arg provided
111+
# deterministic via view
98112
new_m = nnx.view(m, deterministic=True)
99113
y = new_m(x)
114+
np.testing.assert_array_equal(y, x)
100115
# deterministic call arg overrides the value set via view
101-
with pytest.raises(AssertionError):
102-
np.testing.assert_allclose(
103-
y, new_m(x, deterministic=False, rngs=nnx.Rngs(dropout=0))
104-
)
116+
self.assertFalse(
117+
np.array_equal(
118+
y,
119+
new_m(
120+
x, deterministic=False, rngs=nnx.Rngs(dropout=0)
121+
),
122+
),
123+
'deterministic output should differ from stochastic output',
124+
)
105125
# no rng arg provided
106126
new_m = nnx.view(m, deterministic=False)
107-
with pytest.raises(
127+
with self.assertRaisesRegex(
108128
ValueError,
109-
match='`deterministic` is False, but no `rngs` argument was provided to Dropout',
129+
r'`deterministic` is False.*no `rngs` argument',
110130
):
111-
new_m(x)
131+
new_m(x)
132+
133+
def test_deterministic_passthrough(self):
134+
m = nnx.Dropout(rate=0.5, deterministic=True)
135+
x = jnp.ones((20, 20))
136+
y = m(x)
137+
np.testing.assert_array_equal(y, x)
138+
139+
def test_rate_zero(self):
140+
m = nnx.Dropout(
141+
rate=0.0,
142+
deterministic=False,
143+
rngs=nnx.Rngs(dropout=0),
144+
)
145+
x = jnp.ones((20, 20))
146+
y = m(x)
147+
np.testing.assert_array_equal(y, x)
148+
149+
def test_rate_one(self):
150+
m = nnx.Dropout(
151+
rate=1.0,
152+
deterministic=False,
153+
rngs=nnx.Rngs(dropout=0),
154+
)
155+
x = jnp.ones((20, 20))
156+
y = m(x)
157+
np.testing.assert_array_equal(y, jnp.zeros_like(x))
158+
159+
def test_rate_one_gradient_not_nan(self):
160+
m = nnx.Dropout(
161+
rate=1.0,
162+
deterministic=False,
163+
rngs=nnx.Rngs(dropout=0),
164+
)
165+
x = jnp.ones((20, 20))
166+
grad = jax.grad(lambda x: jnp.sum(m(x)))(x)
167+
self.assertFalse(jnp.any(jnp.isnan(grad)))
168+
np.testing.assert_array_equal(grad, jnp.zeros_like(x))
169+
170+
@parameterized.product(
171+
dtype=[jnp.float32, jnp.float16, jnp.bfloat16],
172+
)
173+
def test_dtypes(self, dtype):
174+
m = nnx.Dropout(
175+
rate=0.5,
176+
deterministic=False,
177+
rngs=nnx.Rngs(dropout=0),
178+
)
179+
x = jnp.ones((10, 10), dtype=dtype)
180+
y = m(x)
181+
self.assertEqual(y.dtype, dtype)
182+
183+
def test_rngs_as_jax_array(self):
184+
m = nnx.Dropout(rate=0.5, deterministic=False)
185+
x = jnp.ones((10, 10))
186+
key = random.key(0)
187+
y = m(x, rngs=key)
188+
self.assertTrue(jnp.any(y == 0.0))
189+
self.assertTrue(jnp.any(y > 0.0))
190+
# Kept values should be scaled by 1/keep_prob = 2.0
191+
np.testing.assert_allclose(
192+
y[y > 0.0], 2.0, rtol=1e-6
193+
)
194+
195+
def test_rngs_as_nnx_rngs_in_call(self):
196+
m = nnx.Dropout(rate=0.5, deterministic=False)
197+
x = jnp.ones((10, 10))
198+
y = m(x, rngs=nnx.Rngs(dropout=0))
199+
self.assertTrue(jnp.any(y == 0.0))
200+
self.assertTrue(jnp.any(y > 0.0))
201+
np.testing.assert_allclose(
202+
y[y > 0.0], 2.0, rtol=1e-6
203+
)
204+
205+
def test_custom_rng_collection(self):
206+
m = nnx.Dropout(
207+
rate=0.5,
208+
deterministic=False,
209+
rng_collection='my_dropout',
210+
rngs=nnx.Rngs(my_dropout=0),
211+
)
212+
x = jnp.ones((10, 10))
213+
y = m(x)
214+
self.assertTrue(jnp.any(y == 0.0))
215+
self.assertTrue(jnp.any(y > 0.0))
216+
217+
def test_invalid_rngs_type_constructor(self):
218+
with self.assertRaisesRegex(
219+
TypeError,
220+
r'rngs must be a Rngs, RngStream or None',
221+
):
222+
nnx.Dropout(rate=0.5, rngs='invalid')
223+
224+
def test_invalid_rngs_type_call(self):
225+
m = nnx.Dropout(rate=0.5, deterministic=False)
226+
x = jnp.ones((10, 10))
227+
with self.assertRaisesRegex(
228+
TypeError,
229+
r'rngs must be a Rngs, RngStream or jax\.Array',
230+
):
231+
m(x, rngs='invalid')
232+
233+
@parameterized.parameters(
234+
{
235+
'num_dims': 2,
236+
'broadcast_dims': (1,),
237+
'slice_fn': lambda out, i: out[i, :],
238+
'summed_total': 2 * 10,
239+
},
240+
{
241+
'num_dims': 2,
242+
'broadcast_dims': (0,),
243+
'slice_fn': lambda out, i: out[:, i],
244+
'summed_total': 2 * 10,
245+
},
246+
{
247+
'num_dims': 3,
248+
'broadcast_dims': (1, 2),
249+
'slice_fn': lambda out, i: out[i, :, :],
250+
'summed_total': 2 * 10 * 10,
251+
},
252+
{
253+
'num_dims': 3,
254+
'broadcast_dims': (1,),
255+
'slice_fn': lambda out, i, j: out[i, :, j],
256+
'summed_total': 2 * 10,
257+
},
258+
{
259+
'num_dims': 4,
260+
'broadcast_dims': (0, 2, 3),
261+
'slice_fn': lambda out, i: out[:, i, :, :],
262+
'summed_total': 2 * 10 * 10 * 10,
263+
},
264+
{
265+
'num_dims': 4,
266+
'broadcast_dims': (0, 1),
267+
'slice_fn': lambda out, i, j: out[:, :, i, j],
268+
'summed_total': 2 * 10 * 10,
269+
},
270+
{
271+
'num_dims': 4,
272+
'broadcast_dims': (3,),
273+
'slice_fn': lambda out, i, j, k: out[i, j, k, :],
274+
'summed_total': 2 * 10,
275+
},
276+
)
277+
def test_broadcast_dims(
278+
self, num_dims, broadcast_dims, slice_fn, summed_total
279+
):
280+
m = nnx.Dropout(
281+
rate=0.5,
282+
broadcast_dims=broadcast_dims,
283+
deterministic=False,
284+
rngs=nnx.Rngs(dropout=0),
285+
)
286+
x = jnp.ones((10,) * num_dims)
287+
out = m(x)
288+
289+
n_free = num_dims - len(broadcast_dims)
290+
for indices in itertools.product(range(10), repeat=n_free):
291+
self.assertIn(
292+
float(slice_fn(out, *indices).sum()),
293+
(0, summed_total),
294+
)
295+
296+
def test_rate_stats(self):
297+
n_trials = 10
298+
rootkey = random.key(0)
299+
for rate in np.arange(0.1, 1.0, 0.1):
300+
rootkey, subkey = random.split(rootkey)
301+
m = nnx.Dropout(rate=rate, deterministic=False)
302+
nonzero_counts = 0
303+
for key in random.split(subkey, n_trials):
304+
y = m(
305+
jnp.ones((100, 100)),
306+
rngs=nnx.Rngs(dropout=key),
307+
)
308+
nonzero_counts += np.sum(y > 0.0)
309+
all_counts = np.prod((100, 100, n_trials))
310+
frac = nonzero_counts / all_counts
311+
keep_rate = 1.0 - rate
312+
# check the kept fraction is within 4 standard deviations of keep_rate (binomial proportion over all_counts samples)
313+
delta = (
314+
4
315+
* np.sqrt(rate * keep_rate)
316+
/ np.sqrt(all_counts)
317+
)
318+
self.assertTrue(
319+
keep_rate - delta < frac < keep_rate + delta
320+
)
321+
322+
323+
if __name__ == '__main__':
324+
absltest.main()

0 commit comments

Comments
 (0)