@@ -808,6 +808,46 @@ def h(x):
808808 self .assertArraysEqual (out2 , inp * 6 )
809809 self .assertEqual (out2 .sharding .memory_kind , 'pinned_host' )
810810
811+ def test_compute_on_host_shared_sharding (self ):
812+ mesh = jtu .create_mesh ((2 ,), ("x" ))
813+ device_sharding = NamedSharding (mesh , P ("x" ))
814+ host_sharding = device_sharding .with_memory_kind ("pinned_host" )
815+
816+ @compute_on ("device_host" )
817+ @functools .partial (
818+ jax .jit ,
819+ in_shardings = (host_sharding , device_sharding ),
820+ out_shardings = (host_sharding , device_sharding ),
821+ donate_argnums = (0 , 1 ),
822+ )
823+ def host_func (x , y ):
824+ return (x * y ), ((x ** 2 ) * (y ** 2 ))
825+
826+ @functools .partial (
827+ jax .jit ,
828+ in_shardings = (host_sharding , device_sharding ),
829+ out_shardings = (host_sharding , device_sharding ),
830+ donate_argnums = (0 ),
831+ )
832+ def device_func (host_data , device_data ):
833+ host_data , device_data = host_func (host_data , device_data )
834+ device_data = device_data * 2
835+ host_data , device_data = host_func (host_data , device_data )
836+ return (host_data , device_data )
837+
838+ input_x = jnp .ones (8 )
839+ input_host = jax .device_put (input_x , host_sharding )
840+
841+ input_device = jnp .arange (8 )
842+ input_device = jnp .where (input_device < 4 , 0 , 1 )
843+ input_device = jax .device_put (input_device , device_sharding )
844+
845+ output_host , output_device = device_func (input_host , input_device )
846+ self .assertEqual (output_host .sharding .memory_kind , 'pinned_host' )
847+ self .assertEqual (output_device .sharding .memory_kind , 'device' )
848+ self .assertArraysEqual (output_host , [0. , 0. , 0. , 0. , 2. , 2. , 2. , 2. ])
849+ self .assertArraysEqual (output_device , [0. , 0. , 0. , 0. , 4. , 4. , 4. , 4. ])
850+
811851 def test_compute_on_basic_inline (self ):
812852 @compute_on ('device_host' )
813853 @jax .jit
0 commit comments