@@ -1680,6 +1680,47 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
    self.assertArraysEqual(x, jit_f(x))

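+  # custom_partitioning on a function whose single argument is a pytree
+  # (a 3-tuple of arrays), run under a 4x2 ('x', 'y') mesh.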
+  @jtu.with_mesh([('x', 4), ('y', 2)])
+  def test_custom_partitioner_pytree_inputs(self):
+    self.skip_if_custom_partitioning_not_supported()
+
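+    # The partition callback sees arg_shapes structured like the pytree input
+    # and returns (mesh, per-shard lowering fn, output sharding, arg shardings).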
+    def partition(mesh, arg_shapes, result_shape):
+      def lower_fn(xs):
+        x, y, z = xs
+        return x + y + z
+
+      return (
+          mesh,
+          lower_fn,
+          arg_shapes[0][0].sharding,
+          jax.tree.map(lambda x: x.sharding, arg_shapes),
+      )
+
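+    # The output sharding is inferred from the first leaf of the tuple operand.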
+    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      return arg_shapes[0][0].sharding
+
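+    # A consumer's sharding is propagated back to f unchanged.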
+    def propagate_user_sharding(mesh, user_shape):
+      return user_shape.sharding
+
+    @custom_partitioning
+    def f(xs):
+      x, y, z = xs
+      return x + y + z
+
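+    # sharding_rule is an einsum-like spec; here, one 'i j' term per tuple leaf.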
+    f.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+        propagate_user_sharding=propagate_user_sharding,
+        sharding_rule='i j, i j, i j -> i j',
+    )
+
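+    # f consumes a 3-tuple pytree, so f2 computes a + (a + a + a) = 4 * a.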
+    def f2(a):
+      return a + f((a, a, a))
+
+    pjit_f = pjit(f2, in_shardings=P(None, 'x'), out_shardings=P('x'))
+    x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
+    self.assertArraysEqual(x * 4, pjit_f(x))
+

@jtu.pytest_mark_if_available('multiaccelerator')
class AutoShardingPjitTest(jtu.JaxTestCase):