Hi everyone,

I am working on a hierarchical sharding setup (mesh axes `chain`, `x`, `y`) and I want to confirm that my composition pattern is idiomatic.
My goal is to split the logic into two distinct layers:
Outer "Group" Parallelism (chain): I want to treat this strictly as a batch dimension where the operation is replicated across groups.
Inner "Spatial" Parallelism (x, y): Inside each group, I want to drop into shard_map to perform manual collective operations (like all_to_all or halo exchanges) specifically on that group's mesh.
Why not a single shard_map?
I specifically want to avoid writing a single monolithic shard_map that maps all axes (c, x, y) at once. My reasoning is that I don't want the inner kernel to be aware of the global chain coordinate or the "per-device" slice of the batch. I want to define the operation "per group" (i.e., defined on the x, y grid) and then simply broadcast/vectorize that logic over the chain axis.
I found that composing jax.vmap (to handle the chain axis) with an inner shard_map (to handle x, y) achieves this cleanly. vmap "peels off" the chain dimension, so the inner function only sees the spatial dimensions.
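The "peeling" behavior is easiest to see in isolation, without any sharding at all (toy names, rank-2 transpose standing in for the real spatial op):

```python
import jax
import jax.numpy as jnp

# A toy "per group" function written for a rank-2 spatial array only.
def spatial_fn(arr):
    assert arr.ndim == 2  # vmap has already peeled off the chain axis
    return arr.T

batched = jnp.zeros((3, 4, 5))       # (chain, x, y)
out = jax.vmap(spatial_fn)(batched)  # spatial_fn only ever sees (4, 5)
print(out.shape)  # (3, 5, 4)
```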
Here is my reproduction script:
```python
import os

# 1. Setup environment: force 16 virtual CPU devices
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=16'
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.debug as jdbg
import jax.numpy as jnp
import numpy as np
from functools import partial
from jax import shard_map, lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Mesh definition:
# c=2 (Chain/Batch), x=4 (Spatial X), y=2 (Spatial Y), f=1
# We view this as 2 "groups", where each group is a 4x2 grid.
mesh = Mesh(np.array(jax.devices()[:16]).reshape(2, 4, 2, 1), ('c', 'x', 'y', 'f'))

# Atomic function (the "per group" logic):
# This is defined purely for the spatial mesh ('x', 'y').
# It expects input shape (X_local, Y_local, Z) and knows nothing about 'chain'.
@partial(shard_map, mesh=mesh, in_specs=P('x', 'y', None), out_specs=P('y', 'x', None))
def atomic_op(arr):
    # Proof: print rank to show 'chain' is gone
    print(f"ATOMIC OP (Inner): Rank={arr.ndim}, Shape={arr.shape}")
    jax.debug.inspect_array_sharding(arr, callback=lambda sharding: print(f"Sharding in Atomic Op: {sharding}"))
    # Manual collective: scramble X, then transpose.
    # Note: we stick the result on axis 0, which is X in this local view.
    return lax.all_to_all(arr, axis_name='x', split_axis=0, concat_axis=2, tiled=True).transpose(1, 2, 0)

# Outer function (the "chain" logic):
# We use vmap to vectorize atomic_op over the 'chain' axis.
@jax.jit
@jax.vmap
def vmapped_op(arr):
    print(f"VMAP (Outer): Rank={arr.ndim}, Shape={arr.shape}")
    jax.debug.inspect_array_sharding(arr, callback=lambda sharding: print(f"Sharding in VMAP: {sharding}"))
    return atomic_op(arr)

# Execution
# Global data: (Chain=2, X=32, Y=32, Z=4)
global_data = jnp.zeros((2, 32, 32, 4))
global_data = jax.device_put(global_data, NamedSharding(mesh, P('c', 'x', 'y', 'f')))
print(f"Global Input Sharding: {global_data.sharding}")

# 1. Run the vmapped op
transposed = vmapped_op(global_data)

# 2. Verify against a manual single-slice execution
indx = 1
transposed_first = jax.jit(atomic_op)(global_data[indx])
print(f"Is vmapped result consistent with manual slice? {jnp.allclose(transposed[indx], transposed_first)}")
```
Question:
Is this vmap(shard_map(...)) pattern safe and idiomatic for this "Group-wise" operation model? Specifically, does JAX guarantee that the vmap correctly partitions the P('c') axis such that the inner shard_map operates on the correct subset of devices without communication overhead across the chain?
Thank you for your help.