@@ -632,6 +632,7 @@ def f(x):
632632 python_should_be_executing = False
633633 jit (f )(3 )
634634
635+ @jtu .thread_hostile_test ()
635636 def test_jit_cache_clear (self ):
636637 @jit
637638 def f (x , y ):
@@ -2591,6 +2592,7 @@ def test_block_until_ready_mixed(self):
25912592 self .assertAllClose (pytree [2 ], np .ones (3 ), check_dtypes = False )
25922593 self .assertEqual (pytree [3 ], 4 )
25932594
2595+ @jtu .thread_hostile_test ()
25942596 def test_devicearray_weakref_friendly (self ):
25952597 x = device_put (1. )
25962598 y = weakref .ref (x )
@@ -2739,6 +2741,7 @@ def f(x):
27392741
27402742 self .assertEqual (count (), 1 )
27412743
2744+ @jtu .thread_hostile_test ()
27422745 def test_jit_infer_params_cache (self ):
27432746 def f (x ):
27442747 return x
@@ -3329,6 +3332,7 @@ def test_grad_object_array_error(self):
33293332 with self .assertRaisesRegex (TypeError , ".*is not a valid JAX type" ):
33303333 jax .grad (lambda x : x )(x )
33313334
3335+ @jtu .thread_hostile_test ()
33323336 def test_jit_compilation_time_logging (self ):
33333337 @api .jit
33343338 def f (x ):
@@ -3417,6 +3421,7 @@ def test_trivial_computations(self):
34173421 self .assertNotEqual (z3 .unsafe_buffer_pointer (), x1 .unsafe_buffer_pointer ())
34183422 self .assertEqual (z2 , 1 )
34193423
3424+ @jtu .thread_hostile_test ()
34203425 def test_nested_jit_hoisting (self ):
34213426 @api .jit
34223427 def f (x , y ):
@@ -3454,6 +3459,7 @@ def mlir_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
34543459 self .assertEqual (inner_jaxpr .eqns [- 2 ].primitive .name , 'mul' )
34553460 self .assertEqual (inner_jaxpr .eqns [- 1 ].primitive .name , 'add' )
34563461
3462+ @jtu .thread_hostile_test ()
34573463 def test_primitive_compilation_cache (self ):
34583464 with jtu .count_primitive_compiles () as count :
34593465 lax .add (1 , 2 )
@@ -4013,13 +4019,17 @@ def __jax_array__(self):
40134019 a2 = jnp .array (((x , x ), [x , x ]))
40144020 self .assertAllClose (np .array (((1 , 1 ), (1 , 1 ))), a2 )
40154021
4022+ @jtu .thread_hostile_test ()
40164023 def test_eval_shape_weak_type (self ):
40174024 # https://github.com/jax-ml/jax/issues/23302
40184025 arr = jax .numpy .array (1 )
40194026
4027+ def f (x ):
4028+ return jax .numpy .array (x )
4029+
40204030 with jtu .count_jit_tracing_cache_miss () as count :
4021- jax .eval_shape (jax . numpy . array , 1 )
4022- out = jax .eval_shape (jax . numpy . array , 1 )
4031+ jax .eval_shape (f , 1 )
4032+ out = jax .eval_shape (f , 1 )
40234033
40244034 self .assertEqual (count (), 1 )
40254035 self .assertTrue (out .weak_type )
@@ -4138,6 +4148,7 @@ def test_dot_precision_flag(self):
41384148 jaxpr = jax .make_jaxpr (jnp .dot )(x , x )
41394149 self .assertIn ('Precision.HIGH' , str (jaxpr ))
41404150
4151+ @jtu .thread_hostile_test ()
41414152 def test_dot_precision_forces_retrace (self ):
41424153 num_traces = 0
41434154
@@ -4310,6 +4321,7 @@ def test_jnp_array_doesnt_device_put(self):
43104321 api .make_jaxpr (lambda : jnp .array (3 ))()
43114322 self .assertEqual (count (), 0 )
43124323
4324+ @jtu .thread_hostile_test ()
43134325 def test_rank_promotion_forces_retrace (self ):
43144326 num_traces = 0
43154327
@@ -4328,7 +4340,7 @@ def f_jit(x):
43284340
43294341 for f in [f_jit , f_cond ]:
43304342 # Use _read() to read the flag value rather than threadlocal value.
4331- allow_promotion = config . _read ( "jax_numpy_rank_promotion" )
4343+ allow_promotion = jax . numpy_rank_promotion . get_global ( )
43324344 try :
43334345 config .update ("jax_numpy_rank_promotion" , "allow" )
43344346 num_traces = 0
@@ -4350,9 +4362,9 @@ def f(x):
43504362 self .assertGreaterEqual (num_traces , 2 )
43514363 nt = num_traces
43524364 f (x )
4353- self .assertEqual (num_traces , nt + 1 )
4365+ self .assertEqual (num_traces , nt )
43544366 f (x )
4355- self .assertEqual (num_traces , nt + 1 )
4367+ self .assertEqual (num_traces , nt )
43564368 finally :
43574369 config .update ("jax_numpy_rank_promotion" , allow_promotion )
43584370
@@ -4450,6 +4462,7 @@ def foo(x, y, z):
44504462 self .assertEqual (jfoo .__qualname__ , f"make_jaxpr({ foo .__qualname__ } )" )
44514463 self .assertEqual (jfoo .__module__ , "jax" )
44524464
4465+ @jtu .thread_hostile_test ()
44534466 def test_inner_jit_function_retracing (self ):
44544467 # https://github.com/jax-ml/jax/issues/7155
44554468 inner_count = outer_count = 0
@@ -4691,6 +4704,7 @@ def test_mesh_creation_error_message(self):
46914704 with self .assertRaisesRegex (ValueError , "ndim of its first argument" ):
46924705 jax .sharding .Mesh (jax .devices (), ("x" , "y" ))
46934706
4707+ @jtu .thread_hostile_test ()
46944708 def test_jit_boundmethod_reference_cycle (self ):
46954709 class A :
46964710 def __init__ (self ):
@@ -4829,6 +4843,7 @@ class RematTest(jtu.JaxTestCase):
48294843 ('_policy' , partial (jax .remat , policy = lambda * _ , ** __ : False )),
48304844 ('_new' , partial (new_checkpoint , policy = lambda * _ , ** __ : False )),
48314845 ])
4846+ @jtu .thread_hostile_test ()
48324847 def test_remat_basic (self , remat ):
48334848 @remat
48344849 def g (x ):
@@ -5166,6 +5181,7 @@ def f_yesremat(x):
51665181 ('_policy' , partial (jax .remat , policy = lambda * _ , ** __ : False )),
51675182 ('_new' , partial (new_checkpoint , policy = lambda * _ , ** __ : False )),
51685183 ])
5184+ @jtu .thread_hostile_test ()
51695185 def test_remat_no_redundant_flops (self , remat ):
51705186 # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584
51715187
@@ -6409,6 +6425,7 @@ def f(x):
64096425 self .assertIn (' sin ' , str (jaxpr ))
64106426 self .assertIn (' cos ' , str (jaxpr ))
64116427
6428+ @jtu .thread_hostile_test ()
64126429 def test_remat_residual_logging (self ):
64136430 def f (x ):
64146431 x = jnp .sin (x )
@@ -9626,11 +9643,8 @@ def foo_bwd(_, g):
96269643
96279644 foo .defvjp (foo_fwd , foo_bwd )
96289645
9629- try :
9630- jax .config .update ('jax_custom_vjp_disable_shape_check' , True )
9646+ with config .custom_vjp_disable_shape_check (True ):
96319647 jax .grad (lambda x , y : foo (x , y ).sum (), 1 )(jnp .ones (3 ), jnp .ones (4 ))
9632- finally :
9633- jax .config .update ('jax_custom_vjp_disable_shape_check' , False )
96349648
96359649 def test_bwd_rule_can_produce_list_or_tuple (self ):
96369650 @jax .custom_vjp
@@ -11114,6 +11128,8 @@ def test_autodidax_smoketest(self):
1111411128 spec .loader .exec_module (autodidax_module )
1111511129
1111611130class GarbageCollectionTest (jtu .JaxTestCase ):
11131+
11132+ @jtu .thread_hostile_test ()
1111711133 def test_xla_gc_callback (self ):
1111811134 # https://github.com/jax-ml/jax/issues/14882
1111911135 x_np = np .arange (10 , dtype = 'int32' )
0 commit comments