Skip to content

Commit 57826d8

Browse files
yashk2810jax authors
authored andcommitted
Add a no input memories_test and enable memories test on vf 2x2
PiperOrigin-RevId: 641361865
1 parent 0d047a1 commit 57826d8

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/memories_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,22 @@ def _check_mem_kind(self, executable_kind, out_sharding, expected_kind):
616616
self.assertEqual(out_kind, expected_kind)
617617
self.assertEqual(executable_kind, expected_kind)
618618

619+
def test_compute_no_inputs(self):
620+
mesh = jtu.create_global_mesh((4,), ('data'))
621+
622+
tpu_sharding = NamedSharding(mesh, P('data'))
623+
cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host')
624+
625+
@functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
626+
def init():
627+
tpu_array = jax.random.normal(jax.random.key(42), (16,16))
628+
cpu_array = jax.random.normal(jax.random.key(42), (16,16))
629+
return tpu_array, cpu_array
630+
631+
tpu_array, cpu_array = init()
632+
self.assertEqual(tpu_array.sharding, tpu_sharding)
633+
self.assertEqual(cpu_array.sharding, cpu_sharding)
634+
619635
def test_compute_on_basic(self):
620636
out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')
621637

0 commit comments

Comments
 (0)