Hi,

I am migrating a library ([jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp)) from the legacy `infer_sharding_from_operands` callback to the new `sharding_rule` API compatible with Shardy.
My library implements distributed FFTs. The critical logic relies on inspecting the input sharding (specifically the `PartitionSpec` of the input) to decide:

- Which algorithm to use (e.g., slab decomposition vs. pencil decomposition).
- What the output sharding will look like (the algorithm effectively rotates the sharding axes).
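To make the "rotation" concrete, here is a toy model of the mapping described above. Plain tuples stand in for `PartitionSpec`, and the specific mapping is an assumption inferred from the examples later in this post (Z-sharded input → Y-sharded output, X-sharded input → Z-sharded output); it is not jaxDecomp's actual code.

```python
# Toy illustration of the sharding-axis "rotation": each mesh axis moves
# one array dimension to the left (cyclically). Tuples stand in for
# PartitionSpec so this runs without JAX.

def get_output_spec(in_spec):
    """Return the output spec with every mesh axis shifted one dim left."""
    n = len(in_spec)
    out = [None] * n
    for dim, axis in enumerate(in_spec):
        if axis is not None:
            out[(dim - 1) % n] = axis
    return tuple(out)

# Input sharded on Z (dim 2) -> output sharded on Y (dim 1):
print(get_output_spec((None, None, "z")))  # (None, 'z', None)
# Input sharded on X (dim 0) -> output sharded on Z (dim 2):
print(get_output_spec(("x", None, None)))  # (None, None, 'x')
```

The point is that the input/output dimension correspondence is not fixed: which output dim a mesh axis lands on depends on which input dim it started on.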
### The "Old" Way (Working)
Previously, `infer_sharding_from_operands` provided `arg_infos` populated with `NamedSharding`. I could inspect the input spec at compile time to dynamically determine the output spec.
```python
@spmd_fft_primitive.def_infer_sharding
def infer_sharding_from_operands(mesh, arg_infos, result_infos):
    # 1. Access the input sharding
    input_sharding = arg_infos[0].sharding
    spec = input_sharding.spec
    # 2. Logic: depending on which axis is sharded, the output spec changes.
    # e.g., if input is sharded on Z, output must be sharded on Y (Slab XY algo)
    # e.g., if input is sharded on X, output must be sharded on Z (Slab YZ algo)
    pencil_type = get_pencil_type(spec)
    transposed_specs = get_output_specs(pencil_type, spec)
    return NamedSharding(mesh, P(*transposed_specs))
```
### The "New" Way (The Problem)
In the new API, `sharding_rule` is required for Shardy propagation. However, the `arg_infos` passed to this callback only contain ranked shapes (`ShapeDtypeStruct`), not the sharding.
```python
@spmd_fft_primitive.def_sharding_rule
def fft_sharding_rule_producer(mesh, arg_infos, result_infos):
    # arg_infos[0] is just a ShapeDtypeStruct.
    # I cannot access .sharding to check which axis is distributed!
    # I need to return an einsum-like rule or SdyShardingRule here, but I don't
    # know which rule to return because I don't know the input layout.
    # If I return a generic "i j k -> i j k", it fails because my custom op
    # inherently performs a global transpose (reshuffle) that changes the sharding.
    return ???
```
### The Question
My operation is polymorphic: the relationship between input and output dimensions depends entirely on how the input is currently distributed.
If `sharding_rule` is supposed to be purely declarative (agnostic of the input sharding), how should we handle ops where the propagation rule itself depends on the input layout?
Is the recommended pattern to:

1. Resolve the sharding eagerly in Python (outside the primitive), determine the "mode" (slab/pencil), and pass that mode as a static argument to the primitive?
2. Or is there a way to define an `SdyShardingRule` that can express "if input dim 2 is sharded, output dim 1 becomes sharded"?

Thanks!
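For reference, option 1 might look roughly like the sketch below. Every name in it (`get_pencil_type`, the bind stub) is hypothetical and stands in for jaxDecomp internals; tuples stand in for `PartitionSpec` so the sketch runs without JAX.

```python
# Hedged sketch of option 1: read the layout eagerly in the Python wrapper
# and pass the resulting mode as a *static* argument to the primitive.
# All helper names are placeholders, not jaxDecomp's or JAX's real API.

def get_pencil_type(spec):
    """Classify the decomposition from which dims carry a mesh axis."""
    sharded_dims = [i for i, axis in enumerate(spec) if axis is not None]
    if len(sharded_dims) >= 2:
        return "PENCIL"
    if sharded_dims == [2]:
        return "SLAB_XY"   # sharded on Z -> XY-slab algorithm
    if sharded_dims == [0]:
        return "SLAB_YZ"   # sharded on X -> YZ-slab algorithm
    return "REPLICATED"

def spmd_fft_primitive_bind(x, *, mode):
    # Stub standing in for primitive.bind(x, mode=mode); echoes the mode.
    return mode

def fft(x, spec):
    # `spec` stands in for x.sharding.spec, inspected outside the primitive.
    # With `mode` fixed statically, a purely declarative sharding_rule
    # (a fixed einsum-like string) can be registered per mode.
    mode = get_pencil_type(spec)
    return spmd_fft_primitive_bind(x, mode=mode)
```

The design idea is that once the mode is a static parameter, each mode corresponds to a fixed input/output dimension mapping, which is the shape-only information Shardy's `sharding_rule` can express.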