2828import jax
2929from jax import tree_util
3030from jax ._src import api_util
31+ from jax ._src import config
3132from jax ._src import core
3233from jax ._src import custom_api_util
3334from jax ._src import dispatch
3435from jax ._src import linear_util as lu
3536from jax ._src import mesh as mesh_lib
3637from jax ._src import sharding_impls
3738from jax ._src import xla_bridge as xb
39+ from jax ._src .custom_partitioning_sharding_rule import sdy_sharding_rule_to_mlir , SdyShardingRule , str_to_sdy_sharding_rule
3840from jax ._src .interpreters import mlir
3941from jax ._src .interpreters import partial_eval as pe
4042from jax ._src .lib import xla_client as xc
@@ -225,18 +227,20 @@ def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree,
225227 propagate_user_sharding , partition ,
226228 infer_sharding_from_operands ,
227229 decode_shardings ,
230+ sharding_rule ,
228231 static_args ):
229232 del in_tree , out_tree , propagate_user_sharding , partition
230- del infer_sharding_from_operands , decode_shardings , static_args
233+ del infer_sharding_from_operands , decode_shardings , sharding_rule
234+ del static_args
231235 return call .out_avals
232236
233237
234238def _custom_partitioning_impl (* args , call , in_tree , out_tree ,
235239 propagate_user_sharding ,
236240 partition , infer_sharding_from_operands ,
237- decode_shardings , static_args ):
241+ decode_shardings , sharding_rule , static_args ):
238242 del in_tree , out_tree , propagate_user_sharding , partition
239- del infer_sharding_from_operands , decode_shardings , static_args
243+ del infer_sharding_from_operands , decode_shardings , static_args , sharding_rule
240244 return core .jaxpr_as_fun (call )(* args )
241245
242246
@@ -281,7 +285,14 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape):
281285 arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
282286
283287
284- f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
288+ f.def_partition(partition, propagate_user_sharding,
289+ infer_sharding_from_operands=infer_sharding_from_operands,
290+ sharding_rule='i j -> i j')
291+ When config.use_shardy_partitioner.value is True, the sharding_rule is
292+ used; otherwise, propagate_user_sharding and infer_sharding_from_operands
293+ are used.
294+ Instead of using an Einsum-like notation string, sharding_rule can also be
295+ an SdyShardingRule object, such as sharding_rule=SdyShardingRule(("i", "j"), ("i", "j")).
285296
286297 The args to ``def_partition`` are as follows:
287298
@@ -298,6 +309,10 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape):
298309 * ``decode_shardings``: When set to True, convert input ``GSPMDSharding``s to
299310 ``NamedSharding`` if possible. This may not be possible if the user does not
300311 provide a contextual mesh.
312+ * ``sharding_rule``: Either an SdyShardingRule object or an Einsum-like
313+ notation string that describes the sharding rule. As in the
314+ ``einops.rearrange`` pattern string, factors are separated by spaces, which
315+ allows multi-letter factor names.
301316
302317 Positional arguments can be specified as static using static_argnums. JAX uses
303318 :code:`inspect.signature(fun)` to resolve these positional arguments.
@@ -350,9 +365,16 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
350365 def my_fft(x):
351366 return fft(x)
352367
368+ # Use Einsum-like notation to specify the sharding rule.
353369 my_fft.def_partition(
354- infer_sharding_from_operands=infer_sharding_from_operands,
355- partition=partition)
370+ infer_sharding_from_operands=infer_sharding_from_operands,
371+ partition=partition,
372+ sharding_rule='...i -> ...i')
373+ # Use SdyShardingRule object to specify the sharding rule.
374+ my_fft.def_partition(
375+ infer_sharding_from_operands=infer_sharding_from_operands,
376+ partition=partition,
377+ sharding_rule=SdyShardingRule(operand_mappings=((SDY_BATCHING, 'i'),), result_mappings=((SDY_BATCHING, 'i'),)))
356378
357379 Now create a 2D array sharded along the first axis, pass it through ``my_fft``
358380 and notice how it is still sharded as expected, and identical to the output
@@ -425,15 +447,25 @@ def __init__(self, fun, static_argnums=()):
425447 self .static_argnums = static_argnums
426448 self .propagate_user_sharding = None
427449 self .infer_sharding_from_operands = None
450+ self .sharding_rule = None
428451
429452 __getattr__ : Any = custom_api_util .forward_attr
430453
def def_partition(self, partition, infer_sharding_from_operands,
                  propagate_user_sharding=None, decode_shardings=True,
                  sharding_rule=None):
  """Register the partitioning callbacks and sharding rule on this object.

  Only the inputs relevant to the active partitioner are kept. When
  ``config.use_shardy_partitioner.value`` is true,
  ``infer_sharding_from_operands`` and ``propagate_user_sharding`` are
  discarded and ``sharding_rule`` is stored; otherwise ``sharding_rule`` is
  discarded and the two callbacks are stored.

  Args:
    partition: Partitioning callback; stored as ``self.partition`` and also
      returned.
    infer_sharding_from_operands: Callback stored only when the Shardy
      partitioner is not in use.
    propagate_user_sharding: Optional callback stored only when the Shardy
      partitioner is not in use.
    decode_shardings: Flag stored unchanged as ``self.decode_shardings``.
    sharding_rule: ``None``, an ``SdyShardingRule``, or a value accepted by
      ``str_to_sdy_sharding_rule`` (an Einsum-like string). Used only when
      the Shardy partitioner is in use.

  Returns:
    ``partition``, unchanged.
  """
  if config.use_shardy_partitioner.value:
    infer_sharding_from_operands = None
    propagate_user_sharding = None
  else:
    sharding_rule = None

  # Parse string rules before touching ``self`` so that a malformed rule
  # raises without leaving the object partially registered.
  if sharding_rule is not None and not isinstance(sharding_rule,
                                                  SdyShardingRule):
    sharding_rule = str_to_sdy_sharding_rule(sharding_rule)

  self.partition = partition
  self.propagate_user_sharding = propagate_user_sharding
  self.infer_sharding_from_operands = infer_sharding_from_operands
  self.decode_shardings = decode_shardings
  self.sharding_rule = sharding_rule
  return partition
438470
439471 def __call__ (self , * args , ** kwargs ):
@@ -471,6 +503,7 @@ def __call__(self, *args, **kwargs):
471503 propagate_user_sharding = self .propagate_user_sharding ,
472504 infer_sharding_from_operands = self .infer_sharding_from_operands ,
473505 decode_shardings = self .decode_shardings ,
506+ sharding_rule = self .sharding_rule ,
474507 in_tree = in_tree ,
475508 out_tree = out_tree (),
476509 static_args = static_args
@@ -483,6 +516,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
483516 propagate_user_sharding , partition ,
484517 infer_sharding_from_operands ,
485518 decode_shardings ,
519+ sharding_rule ,
486520 static_args ):
487521 axis_context = ctx .module_context .axis_context
488522 if (isinstance (axis_context , sharding_impls .SPMDAxisContext ) and
@@ -539,6 +573,9 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim):
539573 backend_config = ir .StringAttr .get (key ),
540574 operand_layouts = None ,
541575 result_layouts = None )
576+ if sharding_rule is not None :
577+ value_types = [mlir .aval_to_ir_type (s ) for s in call .in_avals ]
578+ out .attributes ['sdy.sharding_rule' ] = sdy_sharding_rule_to_mlir (sharding_rule , value_types , result_types )
542579 return out .results
543580
544581mlir .register_lowering (custom_partitioning_p ,
0 commit comments