
Commit c4ac0dd

bixia1 authored and Google-ML-Automation committed
Implement the extension to the custom_partitioning API.
Add a sharding rule string and trailing factor_sizes to def_partition, to provide a sharding rule specification when Shardy is used. We use this information to construct a SdyShardingRule and invoke SdyShardingRule.build during MLIR lowering.

Extend the custom_partitioner tests in pjit_test.py to cover the Shardy sharding rule.

PiperOrigin-RevId: 713399604

1 parent b3833dc commit c4ac0dd
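
For orientation, a minimal sketch (not part of this commit) of how the new sharding_rule argument to def_partition is used; the function and the two partitioning callbacks below are simplified stand-ins, and the rule string is only consulted when Shardy is enabled:

    import jax
    from jax.experimental.custom_partitioning import custom_partitioning

    @custom_partitioning
    def double(x):
      return 2 * x

    # Simplified GSPMD callbacks; see the custom_partitioning docs for
    # realistic implementations.
    def partition(mesh, arg_shapes, result_shape):
      arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)
      def lower_fn(x):
        return 2 * x
      return mesh, lower_fn, result_shape.sharding, arg_shardings

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
      return arg_shapes[0].sharding

    double.def_partition(
        partition=partition,
        infer_sharding_from_operands=infer_sharding_from_operands,
        # Einsum-like rule: the 2D result is sharded like the 2D operand.
        sharding_rule='i j -> i j',
    )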

File tree

4 files changed (+70 −16 lines)


jax/_src/custom_partitioning.py

Lines changed: 44 additions & 7 deletions
@@ -28,13 +28,15 @@
 import jax
 from jax import tree_util
 from jax._src import api_util
+from jax._src import config
 from jax._src import core
 from jax._src import custom_api_util
 from jax._src import dispatch
 from jax._src import linear_util as lu
 from jax._src import mesh as mesh_lib
 from jax._src import sharding_impls
 from jax._src import xla_bridge as xb
+from jax._src.custom_partitioning_sharding_rule import sdy_sharding_rule_to_mlir, SdyShardingRule, str_to_sdy_sharding_rule
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lib import xla_client as xc
@@ -225,18 +227,20 @@ def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree,
                                        propagate_user_sharding, partition,
                                        infer_sharding_from_operands,
                                        decode_shardings,
+                                       sharding_rule,
                                        static_args):
   del in_tree, out_tree, propagate_user_sharding, partition
-  del infer_sharding_from_operands, decode_shardings, static_args
+  del infer_sharding_from_operands, decode_shardings, sharding_rule
+  del static_args
   return call.out_avals


 def _custom_partitioning_impl(*args, call, in_tree, out_tree,
                               propagate_user_sharding,
                               partition, infer_sharding_from_operands,
-                              decode_shardings, static_args):
+                              decode_shardings, sharding_rule, static_args):
   del in_tree, out_tree, propagate_user_sharding, partition
-  del infer_sharding_from_operands, decode_shardings, static_args
+  del infer_sharding_from_operands, decode_shardings, static_args, sharding_rule
   return core.jaxpr_as_fun(call)(*args)
@@ -281,7 +285,14 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape):
     arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)


-    f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
+    f.def_partition(partition, propagate_user_sharding,
+                    infer_sharding_from_operands=infer_sharding_from_operands,
+                    sharding_rule='i j -> i j')
+  When config.use_shardy_partitioner.value is True, the sharding_rule is
+  used; otherwise, propagate_user_sharding and infer_sharding_from_operands
+  are used.
+  Instead of using an Einsum-like notation string, sharding_rule can also be
+  a SdyShardingRule object, such as sharding_rule=SdyShardingRule(("i", "j"), ("i", "j")).

  The args to ``def_partition`` are as follows:

@@ -298,6 +309,10 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape):
   * ``decode_shardings``: When set to True, convert input ``GSPMDSharding``s to
     ``NamedSharding`` if possible. This may not be possible if the user does not
     provide a contextual mesh.
+  * ``sharding_rule``: Either an SdyShardingRule object or an Einsum-like
+    notation string that describes the sharding rule. We borrow the idea from
+    the einops.rearrange string, using a space separator between factors and
+    allowing multi-letter factor names.

   Positional arguments can be specified as static using static_argnums. JAX uses
   :code:`inspect.signature(fun)` to resolve these positional arguments.
@@ -350,9 +365,16 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
   def my_fft(x):
     return fft(x)

+  # Use Einsum-like notation to specify the sharding rule.
   my_fft.def_partition(
-    infer_sharding_from_operands=infer_sharding_from_operands,
-    partition=partition)
+    infer_sharding_from_operands=infer_sharding_from_operands,
+    partition=partition,
+    sharding_rule='...i -> ...i')
+  # Use SdyShardingRule object to specify the sharding rule.
+  my_fft.def_partition(
+    infer_sharding_from_operands=infer_sharding_from_operands,
+    partition=partition,
+    sharding_rule=SdyShardingRule(operand_mappings=((SDY_BATCHING, 'i'),), result_mappings=((SDY_BATCHING, 'i'),)))

   Now create a 2D array sharded along the first axis, pass it through ``my_fft``
   and notice how it is still sharded as expected, and identical to the output
@@ -425,15 +447,25 @@ def __init__(self, fun, static_argnums=()):
     self.static_argnums = static_argnums
     self.propagate_user_sharding = None
     self.infer_sharding_from_operands = None
+    self.sharding_rule = None

   __getattr__: Any = custom_api_util.forward_attr

   def def_partition(self, partition, infer_sharding_from_operands,
-                    propagate_user_sharding=None, decode_shardings=True):
+                    propagate_user_sharding=None, decode_shardings=True,
+                    sharding_rule=None):
+    if config.use_shardy_partitioner.value:
+      infer_sharding_from_operands = None
+      propagate_user_sharding = None
+    else:
+      sharding_rule = None
     self.partition = partition
     self.propagate_user_sharding = propagate_user_sharding
     self.infer_sharding_from_operands = infer_sharding_from_operands
     self.decode_shardings = decode_shardings
+    self.sharding_rule = None if sharding_rule is None \
+        else sharding_rule if isinstance(sharding_rule, SdyShardingRule) \
+        else str_to_sdy_sharding_rule(sharding_rule)
     return partition

   def __call__(self, *args, **kwargs):
@@ -471,6 +503,7 @@ def __call__(self, *args, **kwargs):
         propagate_user_sharding=self.propagate_user_sharding,
         infer_sharding_from_operands=self.infer_sharding_from_operands,
         decode_shardings=self.decode_shardings,
+        sharding_rule=self.sharding_rule,
         in_tree=in_tree,
         out_tree=out_tree(),
         static_args=static_args
@@ -483,6 +516,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                            propagate_user_sharding, partition,
                                            infer_sharding_from_operands,
                                            decode_shardings,
+                                           sharding_rule,
                                            static_args):
   axis_context = ctx.module_context.axis_context
   if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
@@ -539,6 +573,9 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim):
539573
backend_config=ir.StringAttr.get(key),
540574
operand_layouts=None,
541575
result_layouts=None)
576+
if sharding_rule is not None:
577+
value_types = [mlir.aval_to_ir_type(s) for s in call.in_avals]
578+
out.attributes['sdy.sharding_rule'] = sdy_sharding_rule_to_mlir(sharding_rule, value_types, result_types)
542579
return out.results
543580

544581
mlir.register_lowering(custom_partitioning_p,
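
As an aside, the Einsum-like string notation described in the docstring above also covers multi-operand, multi-result ops. A hedged sketch, where the string is my own rendering of the SdyShardingRule object used in the updated matmul test in pjit_test.py further down:

    # Intended to mirror
    #   SdyShardingRule(operand_mappings=(('i', 'j'), ('j', 'k')),
    #                   result_mappings=(('i', 'k'), ('i', 'k')))
    # i.e. two operands and two identically mapped results.
    matmul_like_rule = 'i j, j k -> i k, i k'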

jax/_src/custom_partitioning_sharding_rule.py

Lines changed: 9 additions & 2 deletions
@@ -27,6 +27,9 @@
 # leading ... into factors.
 _BATCHING_DIM_FACTOR_PREFIX = "?"

+# A Jax value in general corresponds to an ir.Type or a tuple of ir.Types.
+IrTypes = ir.Type | tuple[ir.Type, ...]
+
 def _check_factor(factor:str):
   """Validates a factor.

@@ -278,8 +281,8 @@ def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule:

 def sdy_sharding_rule_to_mlir(
     rule: SdyShardingRule,
-    operand_types: list[ir.Type],
-    result_types: list[ir.Type],) -> ir.Attribute:
+    operand_types: list[IrTypes],
+    result_types: list[IrTypes],) -> ir.Attribute:
   """Builds the MLIR representation for the sharding rule.

   This is done by verifying that the rule is consistent with the types of
@@ -294,6 +297,10 @@ def sdy_sharding_rule_to_mlir(
     raise ValueError(
         f"Sharding rule has {len(rule.result_mappings)} results, but the operation"
         f" has {len(result_types)} results")
+  if not all(isinstance(t, ir.Type) for t in operand_types + result_types):
+    raise TypeError(
+        f"operand_types and result_types must be a list of ir.Type, but got"
+        f" {operand_types} and {result_types}")

   factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
   types = operand_types + result_types

jax/experimental/custom_partitioning.py

Lines changed: 7 additions & 0 deletions
@@ -19,3 +19,10 @@
     custom_partitioning as custom_partitioning,
     custom_partitioning_p as custom_partitioning_p,
 )
+
+from jax._src.custom_partitioning_sharding_rule import (
+    BATCHING as BATCHING,
+    CompoundFactor as CompoundFactor,
+    ArrayMapping as ArrayMapping,
+    SdyShardingRule as SdyShardingRule,
+)
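
With these re-exports, user code can spell a rule either as a string or as an SdyShardingRule built from the exported names. A small sketch, assuming the two forms below are equivalent (as the def_partition docstring above suggests):

    from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING

    # String form: a leading batching dimension followed by one named factor.
    rule_as_string = '...i -> ...i'

    # Object form of the same rule, using the exported BATCHING marker.
    rule_as_object = SdyShardingRule(
        operand_mappings=((BATCHING, 'i'),),
        result_mappings=((BATCHING, 'i'),))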

tests/pjit_test.py

Lines changed: 10 additions & 7 deletions
@@ -43,7 +43,7 @@
 from jax.sharding import PartitionSpec as P, Mesh
 from jax.experimental import multihost_utils
 from jax.experimental.shard_map import shard_map
-from jax.experimental.custom_partitioning import custom_partitioning
+from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule, BATCHING
 from jax._src import array
 from jax._src.sharding import Sharding, common_devices_indices_map
 from jax._src import op_shardings
@@ -1320,9 +1320,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
   def skip_if_custom_partitioning_not_supported(self):
     if jtu.is_cloud_tpu():
       raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
-    if config.use_shardy_partitioner.value:
-      self.skipTest(
-          'Custom partitioning is not supported with Shardy yet.')

   @jtu.skip_on_devices('cpu')  # Collectives don't seem to work on CPU.
   @jtu.with_mesh([('x', 4), ('y', 2)])
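
With the Shardy skip removed, these tests now run under both partitioners. For local experimentation, the partitioner can be flipped through the usual jax.config mechanism; the flag name below follows the jax_* naming convention behind config.use_shardy_partitioner and should be treated as an assumption:

    import jax

    # Enable Shardy so that def_partition keeps sharding_rule and drops the
    # infer_sharding_from_operands / propagate_user_sharding callbacks.
    jax.config.update('jax_use_shardy_partitioner', True)
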
@@ -1366,7 +1363,8 @@ def f(x, y, precision=None):

     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
-        partition=partition)
+        partition=partition,
+        sharding_rule=SdyShardingRule(operand_mappings=(('i', 'j'), ('j', 'k')), result_mappings=(('i', 'k'), ('i', 'k'))))

     pjit_f = pjit(f, in_shardings=(P('x'), P('y')), out_shardings=P('x'))
     x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
@@ -1406,6 +1404,7 @@ def f(x):
         infer_sharding_from_operands=infer_sharding_from_operands,
         partition=partition,
         propagate_user_sharding=propagate_user_sharding,
+        sharding_rule='i j -> i j',
     )

     def f2(a):
@@ -1442,7 +1441,7 @@ def f(x):
     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
         partition=partition,
-    )
+        sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i'),), result_mappings=((BATCHING, 'i'),)))

     pjit_f = pjit(f, in_shardings=(P(None, 'x')), out_shardings=P('x'))
     x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
@@ -1474,6 +1473,7 @@ def f(x):
     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
         partition=partition,
+        sharding_rule='i j -> i j',
     )

     pjit_f = pjit(f, in_shardings=(P(None, 'x')), out_shardings=P('x'))
@@ -1520,6 +1520,7 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
         partition=partition,
+        sharding_rule='i -> i',
     )

     jit_f = jax.jit(f)
@@ -1552,7 +1553,8 @@ def f(carry, x):
     f.def_partition(
         partition,
         infer_sharding_from_operands=lambda mesh, *_: NamedSharding(mesh, P()),
-        propagate_user_sharding=lambda _, user_shape: user_shape.sharding)
+        propagate_user_sharding=lambda _, user_shape: user_shape.sharding,
+        sharding_rule='i j -> ')  # Result is a scalar.

     pjit_f = pjit(f, in_shardings=P(None, 'x'))
     xs = jnp.ones([32, 16])
@@ -1588,6 +1590,7 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
         partition=partition,
+        sharding_rule='i -> i',
     )

     mesh = jtu.create_mesh((4,), ('x',))
