Skip to content

Commit 0fb5974

Browse files
pschuh authored and Google-ML-Automation committed
Support tuples in custom_partitioning.
PiperOrigin-RevId: 738154413
1 parent 080804c commit 0fb5974

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

jax/_src/custom_partitioning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
179179
for sharding, s in zip(result_shardings, result_shapes)
180180
]
181181
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
182-
*tiled_args
182+
*info.in_tree.unflatten(tiled_args)
183183
)
184184
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
185185
[(t.shape, t.dtype) for t in tiled_results]):

tests/pjit_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,6 +1680,47 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
16801680
jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
16811681
self.assertArraysEqual(x, jit_f(x))
16821682

1683+
@jtu.with_mesh([('x', 4), ('y', 2)])
1684+
def test_custom_partitioner_pytree_inputs(self):
1685+
self.skip_if_custom_partitioning_not_supported()
1686+
1687+
def partition(mesh, arg_shapes, result_shape):
1688+
def lower_fn(xs):
1689+
x, y, z = xs
1690+
return x + y + z
1691+
1692+
return (
1693+
mesh,
1694+
lower_fn,
1695+
arg_shapes[0][0].sharding,
1696+
jax.tree.map(lambda x: x.sharding, arg_shapes),
1697+
)
1698+
1699+
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
1700+
return arg_shapes[0][0].sharding
1701+
1702+
def propagate_user_sharding(mesh, user_shape):
1703+
return user_shape.sharding
1704+
1705+
@custom_partitioning
1706+
def f(xs):
1707+
x, y, z = xs
1708+
return x + y + z
1709+
1710+
f.def_partition(
1711+
infer_sharding_from_operands=infer_sharding_from_operands,
1712+
partition=partition,
1713+
propagate_user_sharding=propagate_user_sharding,
1714+
sharding_rule='i j, i j, i j -> i j',
1715+
)
1716+
1717+
def f2(a):
1718+
return a + f((a, a, a))
1719+
1720+
pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x'))
1721+
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
1722+
self.assertArraysEqual(x * 4, pjit_f(x))
1723+
16831724

16841725
@jtu.pytest_mark_if_available('multiaccelerator')
16851726
class AutoShardingPjitTest(jtu.JaxTestCase):

0 commit comments

Comments (0)