Skip to content

Commit 9041b02

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Account for tokens in allow_spmd_sharding_propagation_to_parameters and allow_spmd_sharding_propagation_to_output compile options
PiperOrigin-RevId: 707723232
1 parent 46b18d2 commit 9041b02

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2909,9 +2909,9 @@ def from_hlo(name: str,
29092909
da = _create_da_object(tuple(device_assignment))
29102910
del device_assignment
29112911

2912-
allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
2913-
for i in in_shardings)
2914-
allow_prop_to_outputs = tuple(
2912+
allow_prop_to_inputs = (False,) * len(ordered_effects) + tuple(
2913+
isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings)
2914+
allow_prop_to_outputs = (False,) * len(ordered_effects) + tuple(
29152915
isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o)
29162916
for o in out_shardings)
29172917

tests/pjit_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4305,6 +4305,24 @@ def f(x):
43054305
self.assertLen(traced.in_avals[0], 1)
43064306
self.assertLen(traced.in_avals[1], 0) # empty kwarg
43074307

4308+
def test_empty_io_callback_under_shard_map(self):
4309+
if config.use_shardy_partitioner.value:
4310+
self.skipTest("Shardy errors out on empty callbacks.")
4311+
mesh = jtu.create_mesh((4,), 'i')
4312+
4313+
def empty_callback(x):
4314+
return
4315+
4316+
def _f(x, y):
4317+
jax.experimental.io_callback(
4318+
empty_callback, (), x, ordered=True)
4319+
return x + y[..., jnp.newaxis]
4320+
4321+
f = jax.jit(shard_map(
4322+
_f, mesh, in_specs=(P(None, 'i'), P(None)),
4323+
out_specs=P(None, 'i')))
4324+
f(jnp.zeros((2, 16)), jnp.ones(2))
4325+
43084326
def test_jit_trace_lower_and_compile(self):
43094327
def f(x):
43104328
return x * 2

0 commit comments

Comments
 (0)