Skip to content

Commit 3b9a8f7

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
Avoid assuming that jnp.sin will be traced in abstract mesh tests
The test does not clear the JAX caches, and jax.sin is a jitted closure that's shared between all test methods, so there's no guarantee that someone hasn't already traced sine at that same shape before. This only shows up rarely since it depends on the subset of tests assigned to the same test executor. PiperOrigin-RevId: 706706380
1 parent 11e0fdf commit 3b9a8f7

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

tests/pjit_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4523,7 +4523,7 @@ def test_different_devices_wsc_abstract_mesh_cache_hit(self):
45234523
def f(x):
45244524
x = with_sharding_constraint(
45254525
x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x')))
4526-
return jnp.sin(x)
4526+
return jax.lax.sin(x)
45274527

45284528
with (
45294529
jtu.count_jit_tracing_cache_miss() as tracing_count,
@@ -4536,7 +4536,8 @@ def f(x):
45364536
# same num_devices but different devices.
45374537
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
45384538
f(b) # tracing and lowering cache *hit*
4539-
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
4539+
4540+
self.assertEqual(tracing_count(), 1)
45404541
self.assertEqual(lowering_count(), 1)
45414542
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.
45424543

tests/shard_map_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,7 @@ def test_different_devices_shmap_abstract_mesh_cache_hit(self):
811811
def f(x):
812812
x = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('i'),
813813
out_specs=P('i'))(x)
814-
return jnp.sin(x)
814+
return jax.lax.sin(x)
815815

816816
with (
817817
jtu.count_jit_tracing_cache_miss() as tracing_count,
@@ -825,7 +825,7 @@ def f(x):
825825
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
826826
f(b) # tracing and lowering cache *hit*
827827

828-
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
828+
self.assertEqual(tracing_count(), 1)
829829
self.assertEqual(lowering_count(), 1)
830830
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.
831831

0 commit comments

Comments
 (0)