Skip to content

Commit 4d60db1

Browse files
Add test_compute_on_host_shared_sharding in memories_test
PiperOrigin-RevId: 698250352
1 parent 6c291d6 commit 4d60db1

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/memories_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,46 @@ def h(x):
808808
self.assertArraysEqual(out2, inp * 6)
809809
self.assertEqual(out2.sharding.memory_kind, 'pinned_host')
810810

811+
def test_compute_on_host_shared_sharding(self):
812+
mesh = jtu.create_mesh((2,), ("x"))
813+
device_sharding = NamedSharding(mesh, P("x"))
814+
host_sharding = device_sharding.with_memory_kind("pinned_host")
815+
816+
@compute_on("device_host")
817+
@functools.partial(
818+
jax.jit,
819+
in_shardings=(host_sharding, device_sharding),
820+
out_shardings=(host_sharding, device_sharding),
821+
donate_argnums=(0, 1),
822+
)
823+
def host_func(x, y):
824+
return (x * y), ((x**2) * (y**2))
825+
826+
@functools.partial(
827+
jax.jit,
828+
in_shardings=(host_sharding, device_sharding),
829+
out_shardings=(host_sharding, device_sharding),
830+
donate_argnums=(0),
831+
)
832+
def device_func(host_data, device_data):
833+
host_data, device_data = host_func(host_data, device_data)
834+
device_data = device_data * 2
835+
host_data, device_data = host_func(host_data, device_data)
836+
return (host_data, device_data)
837+
838+
input_x = jnp.ones(8)
839+
input_host = jax.device_put(input_x, host_sharding)
840+
841+
input_device = jnp.arange(8)
842+
input_device = jnp.where(input_device < 4, 0, 1)
843+
input_device = jax.device_put(input_device, device_sharding)
844+
845+
output_host, output_device = device_func(input_host, input_device)
846+
self.assertEqual(output_host.sharding.memory_kind, 'pinned_host')
847+
self.assertEqual(output_device.sharding.memory_kind, 'device')
848+
self.assertArraysEqual(output_host, [0., 0., 0., 0., 2., 2., 2., 2.])
849+
self.assertArraysEqual(output_device, [0., 0., 0., 0., 4., 4., 4., 4.])
850+
811851
def test_compute_on_basic_inline(self):
812852
@compute_on('device_host')
813853
@jax.jit

0 commit comments

Comments
 (0)