Commit 0e7f218

pschuh authored and Google-ML-Automation committed

Support axis_index inside shard_map(auto=...) by using iota and
then calling full_to_shard.

PiperOrigin-RevId: 705704369

1 parent 1453a22
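
For intuition, here is a plain-NumPy simulation of the iota-then-full_to_shard idea from the commit message (illustrative only: the variable names are mine, and the real lowering emits HLO ops, not NumPy calls). A global iota along the named axis, sharded so that each manual shard holds one element, leaves every shard holding exactly its own axis index:

    import numpy as np

    # Axis 'i' of size 4: each of the 4 shards ends up with its own index.
    size_i = 4
    global_iota = np.arange(size_i, dtype=np.int32)  # stands in for hlo.iota
    shards = np.split(global_iota, size_i)           # stands in for full_to_shard
    indices = [s.reshape(()) for s in shards]        # reshape (1,) -> scalar, as in the diff
    assert [int(s) for s in indices] == [0, 1, 2, 3]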

2 files changed: +35 −6 lines changed

jax/_src/lax/parallel.py

Lines changed: 19 additions & 6 deletions
@@ -1591,20 +1591,33 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
   axis_name, = axis_name
   if axis_name not in axis_env.names:
     raise NameError(f"unbound axis name: {axis_name}")
+  axis_context = ctx.module_context.axis_context
   axis_pos = list(axis_env.names).index(axis_name)
+
+  # For partial auto, lower using iota.
+  if (isinstance(axis_context, SPMDAxisContext) and
+      axis_context.manual_axes and
+      axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
+    x = hlo.iota(ir.RankedTensorType.get(
+        [axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
+    sharding_proto = (
+        NamedSharding(axis_context.mesh, P(axis_name))
+        ._to_xla_hlo_sharding(1).to_proto())
+    aval_in = ShapedArray((axis_env.sizes[axis_pos],), np.int32)
+    aval_out = ShapedArray((1,), np.int32)
+    sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto)
+    proto = pxla.manual_proto(aval_in, axis_context.manual_axes, axis_context.mesh)
+    x = mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto)
+    return hlo.reshape(ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), x)
+
   nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
   div = mlir.ir_constant(
       np.array(
           nreplicas * math.prod(axis_env.sizes[axis_pos + 1 :]), dtype=np.uint32
       )
   )
   mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
-  axis_context = ctx.module_context.axis_context
-  is_spmd = isinstance(
-      axis_context,
-      (SPMDAxisContext, ShardingContext),
-  )
-  if is_spmd:
+  if isinstance(axis_context, (ShardingContext, SPMDAxisContext)):
     device_id = hlo.partition_id()
   else:
     device_id = hlo.replica_id()
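
For comparison, the pre-existing fallback above derives the index arithmetically from a linear device id via the div/mod constants. A plain-Python sketch of that arithmetic (the helper name is hypothetical, and a single replica is assumed):

    import math

    def axis_index_from_device_id(device_id, axis_names, axis_sizes, axis_name,
                                  nreplicas=1):
      # Row-major linearization of mesh coordinates: trailing axes vary fastest.
      pos = list(axis_names).index(axis_name)
      div = nreplicas * math.prod(axis_sizes[pos + 1:])
      return (device_id // div) % axis_sizes[pos]

    # On a 4x2 ('i', 'j') mesh, device 5 sits at i=2, j=1.
    assert axis_index_from_device_id(5, ('i', 'j'), (4, 2), 'i') == 2
    assert axis_index_from_device_id(5, ('i', 'j'), (4, 2), 'j') == 1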

tests/shard_map_test.py

Lines changed: 16 additions & 0 deletions
@@ -2147,6 +2147,22 @@ def f():
 
     self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))
 
+  def test_partial_auto_axis_index(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest('Shardy does not support full-to-shard.')
+
+    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
+    out_sharding = NamedSharding(mesh, P('i', None))
+
+    @partial(jax.jit, out_shardings=out_sharding)
+    def f():
+      return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1),
+                       mesh, in_specs=P('i', None), out_specs=P('i', None),
+                       check_rep=False, auto=frozenset({'j'}))()
+
+    self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1))
+
   def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
     # https://github.com/jax-ml/jax/pull/21032
     mesh = jtu.create_mesh((4, 2), ('i', 'j'))
