1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import itertools
1516
17+ import jax
1618import jax .numpy as jnp
19+ from absl .testing import absltest
20+ from absl .testing import parameterized
21+ from jax import random
1722import numpy as np
1823
1924from flax import nnx
2025
21- import pytest
2226
23-
24- class TestStochastic :
27+ class TestDropout (parameterized .TestCase ):
2528 def test_dropout_internal_rngs (self ):
2629 n = 0
27- m1 = nnx .Dropout (rate = 0.5 , deterministic = False , rngs = nnx .Rngs (dropout = 0 ))
30+ m1 = nnx .Dropout (
31+ rate = 0.5 , deterministic = False , rngs = nnx .Rngs (dropout = 0 )
32+ )
2833 m2 = nnx .Dropout (rate = 0.5 , deterministic = False )
2934 rngs2 = nnx .Rngs (dropout = 0 ).fork ()
3035
@@ -35,34 +40,42 @@ def f(m, x, rngs=None):
3540 return m (x , rngs = rngs )
3641
3742 x = jnp .ones ((1 , 10 ))
38- assert m1 .rngs is not None and m1 .rngs .count [...] == 0
43+ self .assertIsNotNone (m1 .rngs )
44+ self .assertEqual (m1 .rngs .count [...], 0 )
3945
4046 y1 = f (m1 , x )
41- assert n == 1
42- assert m1 .rngs .count [...] == 1
47+ self . assertEqual ( n , 1 )
48+ self . assertEqual ( m1 .rngs .count [...], 1 )
4349 y2 = f (m2 , x , rngs = rngs2 )
44- assert n == 2
45- assert rngs2 .dropout .count [...] == 1
50+ self . assertEqual ( n , 2 )
51+ self . assertEqual ( rngs2 .dropout .count [...], 1 )
4652 np .testing .assert_allclose (y1 , y2 )
4753
4854 y1 = f (m1 , x )
49- assert m1 .rngs .count [...] == 2
55+ self . assertEqual ( m1 .rngs .count [...], 2 )
5056 y2 = f (m2 , x , rngs = rngs2 )
51- assert rngs2 .dropout .count [...] == 2
57+ self . assertEqual ( rngs2 .dropout .count [...], 2 )
5258 np .testing .assert_allclose (y1 , y2 )
5359
54- assert n == 2
60+ self . assertEqual ( n , 2 )
5561
5662 def test_dropout_rng_override (self ):
57- m1 = nnx .Dropout (rate = 0.5 , deterministic = False , rngs = nnx .Rngs (dropout = 0 ))
58- m2 = nnx .Dropout (rate = 0.5 , deterministic = False , rngs = nnx .Rngs (dropout = 1 ))
59- x = jnp .ones ((1 , 10 ))
63+ m1 = nnx .Dropout (
64+ rate = 0.5 , deterministic = False , rngs = nnx .Rngs (dropout = 0 )
65+ )
66+ m2 = nnx .Dropout (
67+ rate = 0.5 , deterministic = False , rngs = nnx .Rngs (dropout = 1 )
68+ )
69+ x = jnp .ones ((10 , 10 ))
6070
6171 y1 = m1 (x )
6272 y2 = m2 (x )
63- with pytest .raises (AssertionError ):
64- np .testing .assert_allclose (y1 , y2 )
73+ self .assertFalse (
74+ np .array_equal (y1 , y2 ),
75+ 'Different RNG seeds should produce different masks' ,
76+ )
6577
78+ # Override m2's seed with m1's seed -- outputs should match
6679 y2 = m2 (x , rngs = nnx .Rngs (dropout = 0 ).fork ())
6780 np .testing .assert_allclose (y1 , y2 )
6881
@@ -71,41 +84,241 @@ def test_dropout_arg_override(self):
7184 x = jnp .ones ((1 , 10 ))
7285
7386 # deterministic call arg provided
74- m (x , deterministic = True )
87+ y_det = m (x , deterministic = True )
88+ np .testing .assert_array_equal (y_det , x )
7589 # deterministic constructor arg provided
7690 m .set_attributes (deterministic = True )
7791 y = m (x )
7892 # both deterministic call and constructor arg provided
79- with pytest . raises ( AssertionError ):
80- np .testing . assert_allclose (
93+ self . assertFalse (
94+ np .array_equal (
8195 y , m (x , deterministic = False , rngs = nnx .Rngs (dropout = 0 ))
82- )
96+ ),
97+ 'deterministic output should differ from stochastic output' ,
98+ )
8399 # no rng arg provided
84100 m .set_attributes (deterministic = False )
85- with pytest . raises (
101+ with self . assertRaisesRegex (
86102 ValueError ,
87- match = '`deterministic` is False, but no `rngs` argument was provided to Dropout ' ,
103+ r '`deterministic` is False.* no `rngs` argument' ,
88104 ):
89105 m (x )
90106
91107 def test_dropout_arg_override_view (self ):
92108 m = nnx .Dropout (rate = 0.5 )
93109 x = jnp .ones ((1 , 10 ))
94110
95- # deterministic call arg provided
96- m (x , deterministic = True )
97- # deterministic constructor arg provided
111+ # deterministic via view
98112 new_m = nnx .view (m , deterministic = True )
99113 y = new_m (x )
114+ np .testing .assert_array_equal (y , x )
100115 # both deterministic call and constructor arg provided
101- with pytest .raises (AssertionError ):
102- np .testing .assert_allclose (
103- y , new_m (x , deterministic = False , rngs = nnx .Rngs (dropout = 0 ))
104- )
116+ self .assertFalse (
117+ np .array_equal (
118+ y ,
119+ new_m (
120+ x , deterministic = False , rngs = nnx .Rngs (dropout = 0 )
121+ ),
122+ ),
123+ 'deterministic output should differ from stochastic output' ,
124+ )
105125 # no rng arg provided
106126 new_m = nnx .view (m , deterministic = False )
107- with pytest . raises (
127+ with self . assertRaisesRegex (
108128 ValueError ,
109- match = '`deterministic` is False, but no `rngs` argument was provided to Dropout ' ,
129+ r '`deterministic` is False.* no `rngs` argument' ,
110130 ):
111- new_m (x )
131+ new_m (x )
132+
133+ def test_deterministic_passthrough (self ):
134+ m = nnx .Dropout (rate = 0.5 , deterministic = True )
135+ x = jnp .ones ((20 , 20 ))
136+ y = m (x )
137+ np .testing .assert_array_equal (y , x )
138+
139+ def test_rate_zero (self ):
140+ m = nnx .Dropout (
141+ rate = 0.0 ,
142+ deterministic = False ,
143+ rngs = nnx .Rngs (dropout = 0 ),
144+ )
145+ x = jnp .ones ((20 , 20 ))
146+ y = m (x )
147+ np .testing .assert_array_equal (y , x )
148+
149+ def test_rate_one (self ):
150+ m = nnx .Dropout (
151+ rate = 1.0 ,
152+ deterministic = False ,
153+ rngs = nnx .Rngs (dropout = 0 ),
154+ )
155+ x = jnp .ones ((20 , 20 ))
156+ y = m (x )
157+ np .testing .assert_array_equal (y , jnp .zeros_like (x ))
158+
159+ def test_rate_one_gradient_not_nan (self ):
160+ m = nnx .Dropout (
161+ rate = 1.0 ,
162+ deterministic = False ,
163+ rngs = nnx .Rngs (dropout = 0 ),
164+ )
165+ x = jnp .ones ((20 , 20 ))
166+ grad = jax .grad (lambda x : jnp .sum (m (x )))(x )
167+ self .assertFalse (jnp .any (jnp .isnan (grad )))
168+ np .testing .assert_array_equal (grad , jnp .zeros_like (x ))
169+
170+ @parameterized .product (
171+ dtype = [jnp .float32 , jnp .float16 , jnp .bfloat16 ],
172+ )
173+ def test_dtypes (self , dtype ):
174+ m = nnx .Dropout (
175+ rate = 0.5 ,
176+ deterministic = False ,
177+ rngs = nnx .Rngs (dropout = 0 ),
178+ )
179+ x = jnp .ones ((10 , 10 ), dtype = dtype )
180+ y = m (x )
181+ self .assertEqual (y .dtype , dtype )
182+
183+ def test_rngs_as_jax_array (self ):
184+ m = nnx .Dropout (rate = 0.5 , deterministic = False )
185+ x = jnp .ones ((10 , 10 ))
186+ key = random .key (0 )
187+ y = m (x , rngs = key )
188+ self .assertTrue (jnp .any (y == 0.0 ))
189+ self .assertTrue (jnp .any (y > 0.0 ))
190+ # Kept values should be scaled by 1/keep_prob = 2.0
191+ np .testing .assert_allclose (
192+ y [y > 0.0 ], 2.0 , rtol = 1e-6
193+ )
194+
195+ def test_rngs_as_nnx_rngs_in_call (self ):
196+ m = nnx .Dropout (rate = 0.5 , deterministic = False )
197+ x = jnp .ones ((10 , 10 ))
198+ y = m (x , rngs = nnx .Rngs (dropout = 0 ))
199+ self .assertTrue (jnp .any (y == 0.0 ))
200+ self .assertTrue (jnp .any (y > 0.0 ))
201+ np .testing .assert_allclose (
202+ y [y > 0.0 ], 2.0 , rtol = 1e-6
203+ )
204+
205+ def test_custom_rng_collection (self ):
206+ m = nnx .Dropout (
207+ rate = 0.5 ,
208+ deterministic = False ,
209+ rng_collection = 'my_dropout' ,
210+ rngs = nnx .Rngs (my_dropout = 0 ),
211+ )
212+ x = jnp .ones ((10 , 10 ))
213+ y = m (x )
214+ self .assertTrue (jnp .any (y == 0.0 ))
215+ self .assertTrue (jnp .any (y > 0.0 ))
216+
217+ def test_invalid_rngs_type_constructor (self ):
218+ with self .assertRaisesRegex (
219+ TypeError ,
220+ r'rngs must be a Rngs, RngStream or None' ,
221+ ):
222+ nnx .Dropout (rate = 0.5 , rngs = 'invalid' )
223+
224+ def test_invalid_rngs_type_call (self ):
225+ m = nnx .Dropout (rate = 0.5 , deterministic = False )
226+ x = jnp .ones ((10 , 10 ))
227+ with self .assertRaisesRegex (
228+ TypeError ,
229+ r'rngs must be a Rngs, RngStream or jax\.Array' ,
230+ ):
231+ m (x , rngs = 'invalid' )
232+
233+ @parameterized .parameters (
234+ {
235+ 'num_dims' : 2 ,
236+ 'broadcast_dims' : (1 ,),
237+ 'slice_fn' : lambda out , i : out [i , :],
238+ 'summed_total' : 2 * 10 ,
239+ },
240+ {
241+ 'num_dims' : 2 ,
242+ 'broadcast_dims' : (0 ,),
243+ 'slice_fn' : lambda out , i : out [:, i ],
244+ 'summed_total' : 2 * 10 ,
245+ },
246+ {
247+ 'num_dims' : 3 ,
248+ 'broadcast_dims' : (1 , 2 ),
249+ 'slice_fn' : lambda out , i : out [i , :, :],
250+ 'summed_total' : 2 * 10 * 10 ,
251+ },
252+ {
253+ 'num_dims' : 3 ,
254+ 'broadcast_dims' : (1 ,),
255+ 'slice_fn' : lambda out , i , j : out [i , :, j ],
256+ 'summed_total' : 2 * 10 ,
257+ },
258+ {
259+ 'num_dims' : 4 ,
260+ 'broadcast_dims' : (0 , 2 , 3 ),
261+ 'slice_fn' : lambda out , i : out [:, i , :, :],
262+ 'summed_total' : 2 * 10 * 10 * 10 ,
263+ },
264+ {
265+ 'num_dims' : 4 ,
266+ 'broadcast_dims' : (0 , 1 ),
267+ 'slice_fn' : lambda out , i , j : out [:, :, i , j ],
268+ 'summed_total' : 2 * 10 * 10 ,
269+ },
270+ {
271+ 'num_dims' : 4 ,
272+ 'broadcast_dims' : (3 ,),
273+ 'slice_fn' : lambda out , i , j , k : out [i , j , k , :],
274+ 'summed_total' : 2 * 10 ,
275+ },
276+ )
277+ def test_broadcast_dims (
278+ self , num_dims , broadcast_dims , slice_fn , summed_total
279+ ):
280+ m = nnx .Dropout (
281+ rate = 0.5 ,
282+ broadcast_dims = broadcast_dims ,
283+ deterministic = False ,
284+ rngs = nnx .Rngs (dropout = 0 ),
285+ )
286+ x = jnp .ones ((10 ,) * num_dims )
287+ out = m (x )
288+
289+ n_free = num_dims - len (broadcast_dims )
290+ for indices in itertools .product (range (10 ), repeat = n_free ):
291+ self .assertIn (
292+ float (slice_fn (out , * indices ).sum ()),
293+ (0 , summed_total ),
294+ )
295+
296+ def test_rate_stats (self ):
297+ n_trials = 10
298+ rootkey = random .key (0 )
299+ for rate in np .arange (0.1 , 1.0 , 0.1 ):
300+ rootkey , subkey = random .split (rootkey )
301+ m = nnx .Dropout (rate = rate , deterministic = False )
302+ nonzero_counts = 0
303+ for key in random .split (subkey , n_trials ):
304+ y = m (
305+ jnp .ones ((100 , 100 )),
306+ rngs = nnx .Rngs (dropout = key ),
307+ )
308+ nonzero_counts += np .sum (y > 0.0 )
309+ all_counts = np .prod ((100 , 100 , n_trials ))
310+ frac = nonzero_counts / all_counts
311+ keep_rate = 1.0 - rate
312+ # check within 4 sigma
313+ delta = (
314+ 4
315+ * np .sqrt (rate * keep_rate )
316+ / np .sqrt (all_counts )
317+ )
318+ self .assertTrue (
319+ keep_rate - delta < frac < keep_rate + delta
320+ )
321+
322+
323+ if __name__ == '__main__' :
324+ absltest .main ()
0 commit comments