Commit c22bed5

Merge pull request #197 from ROCm/ci-upstream-sync-83_1

CI: 01/09/25 upstream sync

2 parents: 90eab82 + 9d34a49

33 files changed: +1275 −173 lines

WORKSPACE

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ python_init_repositories(
         "3.11": "//build:requirements_lock_3_11.txt",
         "3.12": "//build:requirements_lock_3_12.txt",
         "3.13": "//build:requirements_lock_3_13.txt",
+        "3.13-ft": "//build:requirements_lock_3_13_ft.txt",
     },
     local_wheel_inclusion_list = [
         "jaxlib*",

build/build.py

Lines changed: 13 additions & 1 deletion
@@ -69,7 +69,6 @@ def add_global_arguments(parser: argparse.ArgumentParser):
   parser.add_argument(
       "--python_version",
       type=str,
-      choices=["3.10", "3.11", "3.12", "3.13"],
       default=f"{sys.version_info.major}.{sys.version_info.minor}",
       help=
       """
@@ -390,10 +389,23 @@ async def main():
     bazel_command_base.append("run")

   if args.python_version:
+    # Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
+    # if bazel_options override it
+    python_version_opt = "--repo_env=HERMETIC_PYTHON_VERSION="
+    if any([python_version_opt in opt for opt in args.bazel_options]):
+      raise RuntimeError(
+          "Please use python_version to set hermetic python version instead of "
+          "setting --repo_env=HERMETIC_PYTHON_VERSION=<python version> bazel option"
+      )
     logging.debug("Hermetic Python version: %s", args.python_version)
     bazel_command_base.append(
         f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}"
     )
+    # Let's interpret X.YY-ft version as free-threading python and set rules_python config flag:
+    if args.python_version.endswith("-ft"):
+      bazel_command_base.append(
+          "--@rules_python//python/config_settings:py_freethreaded='yes'"
+      )

   # Enable verbose failures.
   bazel_command_base.append("--verbose_failures=true")
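
Net effect: with the hard-coded choices list gone, --python_version accepts any X.Y string, and an X.Y-ft suffix routes the build to the free-threaded lock file registered in WORKSPACE. A standalone sketch of the resulting flag logic (python_version_flags is a hypothetical helper, not part of build.py):

    def python_version_flags(python_version: str, bazel_options: list[str]) -> list[str]:
        # Mirror of the new checks: a manual --repo_env override is rejected so
        # that --python_version stays the single source of truth.
        python_version_opt = "--repo_env=HERMETIC_PYTHON_VERSION="
        if any(python_version_opt in opt for opt in bazel_options):
            raise RuntimeError(
                "Please use python_version to set hermetic python version instead of "
                "setting --repo_env=HERMETIC_PYTHON_VERSION=<python version> bazel option")
        flags = [f"--repo_env=HERMETIC_PYTHON_VERSION={python_version}"]
        # An X.YY-ft version selects a free-threaded interpreter via rules_python.
        if python_version.endswith("-ft"):
            flags.append("--@rules_python//python/config_settings:py_freethreaded='yes'")
        return flags

    print(python_version_flags("3.13-ft", []))
    # ['--repo_env=HERMETIC_PYTHON_VERSION=3.13-ft',
    #  "--@rules_python//python/config_settings:py_freethreaded='yes'"]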

build/requirements_lock_3_13_ft.txt

Lines changed: 711 additions & 0 deletions
Large diffs are not rendered by default.

jax/BUILD

Lines changed: 2 additions & 0 deletions
@@ -170,6 +170,7 @@ py_library(
     ] + jax_internal_export_back_compat_test_util_visibility,
     deps = [
         ":jax",
+        ":test_util",
     ] + py_deps("numpy"),
 )

@@ -866,6 +867,7 @@ pytype_strict_library(
         ":partition_spec",
         ":sharding",
         ":sharding_specs",
+        ":source_info_util",
         ":tree_util",
         ":util",
         ":xla_bridge",

jax/_src/config.py

Lines changed: 3 additions & 0 deletions
@@ -233,6 +233,8 @@ def trace_context():
 class NoDefault: pass
 no_default = NoDefault()

+config_states = {}
+
 class State(config_ext.Config[_T]):

   __slots__ = (
@@ -265,6 +267,7 @@ def __init__(
     self._validator(default)
     if self._update_global_hook:
       self._update_global_hook(default)
+    config_states[name] = self

   def __bool__(self) -> NoReturn:
     raise TypeError(
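
The module-level config_states dict added here registers every State at construction time, keyed by name, so config entries can later be enumerated or looked up. A stripped-down sketch of the registry pattern (the State class below is a toy stand-in, not jax._src.config.State):

    config_states = {}

    class State:
        def __init__(self, name, default):
            self.name = name
            self.value = default
            # Register each instance as it is constructed, as the diff does at
            # the end of State.__init__.
            config_states[name] = self

    flag = State("example_flag", default=False)  # hypothetical flag name
    assert config_states["example_flag"] is flag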

jax/_src/core.py

Lines changed: 3 additions & 0 deletions
@@ -1696,6 +1696,9 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None):
     self.weak_type = weak_type
     if config.sharding_in_types.value:
       self.sharding = get_sharding(sharding, len(self.shape))
+      if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh):
+        raise ValueError(
+            f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}")

   def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
     if shape is None:
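
The added check fails fast at aval construction when a sharding carries a concrete mesh instead of an AbstractMesh. A toy sketch of the same guard (all classes here are stand-ins for the real JAX types):

    class AbstractMesh: pass
    class ConcreteMesh: pass

    class Sharding:
        def __init__(self, mesh):
            self.mesh = mesh

    def check_aval_sharding(sharding):
        # Mirrors the new guard: only an AbstractMesh may back an aval's sharding.
        if not isinstance(sharding.mesh, AbstractMesh):
            raise ValueError(
                f"Mesh of an aval must be an AbstractMesh. Got {sharding.mesh}")
        return sharding

    check_aval_sharding(Sharding(AbstractMesh()))    # ok
    # check_aval_sharding(Sharding(ConcreteMesh()))  # raises ValueError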

jax/_src/custom_partitioning.py

Lines changed: 44 additions & 7 deletions
@@ -28,13 +28,15 @@
 import jax
 from jax import tree_util
 from jax._src import api_util
+from jax._src import config
 from jax._src import core
 from jax._src import custom_api_util
 from jax._src import dispatch
 from jax._src import linear_util as lu
 from jax._src import mesh as mesh_lib
 from jax._src import sharding_impls
 from jax._src import xla_bridge as xb
+from jax._src.custom_partitioning_sharding_rule import sdy_sharding_rule_to_mlir, SdyShardingRule, str_to_sdy_sharding_rule
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lib import xla_client as xc
@@ -225,18 +227,20 @@ def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree,
                                        propagate_user_sharding, partition,
                                        infer_sharding_from_operands,
                                        decode_shardings,
+                                       sharding_rule,
                                        static_args):
   del in_tree, out_tree, propagate_user_sharding, partition
-  del infer_sharding_from_operands, decode_shardings, static_args
+  del infer_sharding_from_operands, decode_shardings, sharding_rule
+  del static_args
   return call.out_avals


 def _custom_partitioning_impl(*args, call, in_tree, out_tree,
                               propagate_user_sharding,
                               partition, infer_sharding_from_operands,
-                              decode_shardings, static_args):
+                              decode_shardings, sharding_rule, static_args):
   del in_tree, out_tree, propagate_user_sharding, partition
-  del infer_sharding_from_operands, decode_shardings, static_args
+  del infer_sharding_from_operands, decode_shardings, static_args, sharding_rule
   return core.jaxpr_as_fun(call)(*args)
@@ -281,7 +285,14 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape):
         arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)


-    f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
+    f.def_partition(partition, propagate_user_sharding,
+                    infer_sharding_from_operands=infer_sharding_from_operands,
+                    sharding_rule='i j -> i j')
+
+  When config.use_shardy_partitioner.value is True, the sharding_rule is
+  used; otherwise, propagate_user_sharding and infer_sharding_from_operands
+  are used.
+  Instead of using an Einsum-like notation string, sharding_rule can also be
+  a SdyShardingRule object, such as sharding_rule=SdyShardingRule(("i", "j"), ("i", "j")).

   The args to ``def_partition`` are as follows:
@@ -298,6 +309,10 @@ def infer_sharding_from_operands(mesh, arg_shapes, shape):
   * ``decode_shardings``: When set to True, convert input ``GSPMDSharding``s to
     ``NamedSharding`` if possible. This may not be possible if the user does not
     provide a contextual mesh.
+  * ``sharding_rule``: Either an SdyShardingRule object or an Einsum-like
+    notation string that describes the sharding rule. We borrow the idea from
+    einops.rearrange strings, using a space separator between factors and
+    allowing multi-letter factor names.

   Positional arguments can be specified as static using static_argnums. JAX uses
   :code:`inspect.signature(fun)` to resolve these positional arguments.
@@ -350,9 +365,16 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
     def my_fft(x):
       return fft(x)

+    # Use Einsum-like notation to specify the sharding rule.
     my_fft.def_partition(
-      infer_sharding_from_operands=infer_sharding_from_operands,
-      partition=partition)
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+        sharding_rule='...i -> ...i')
+    # Use an SdyShardingRule object to specify the sharding rule.
+    my_fft.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+        sharding_rule=SdyShardingRule(operand_mappings=((SDY_BATCHING, 'i'),), result_mappings=((SDY_BATCHING, 'i'),)))

   Now create a 2D array sharded along the first axis, pass it through ``my_fft``
   and notice how it is still sharded as expected, and identical to the output
@@ -425,15 +447,25 @@ def __init__(self, fun, static_argnums=()):
     self.static_argnums = static_argnums
     self.propagate_user_sharding = None
     self.infer_sharding_from_operands = None
+    self.sharding_rule = None

   __getattr__: Any = custom_api_util.forward_attr

   def def_partition(self, partition, infer_sharding_from_operands,
-                    propagate_user_sharding=None, decode_shardings=True):
+                    propagate_user_sharding=None, decode_shardings=True,
+                    sharding_rule=None):
+    if config.use_shardy_partitioner.value:
+      infer_sharding_from_operands = None
+      propagate_user_sharding = None
+    else:
+      sharding_rule = None
     self.partition = partition
     self.propagate_user_sharding = propagate_user_sharding
     self.infer_sharding_from_operands = infer_sharding_from_operands
     self.decode_shardings = decode_shardings
+    self.sharding_rule = None if sharding_rule is None \
+        else sharding_rule if isinstance(sharding_rule, SdyShardingRule) \
+        else str_to_sdy_sharding_rule(sharding_rule)
     return partition

   def __call__(self, *args, **kwargs):
@@ -471,6 +503,7 @@ def __call__(self, *args, **kwargs):
         propagate_user_sharding=self.propagate_user_sharding,
         infer_sharding_from_operands=self.infer_sharding_from_operands,
         decode_shardings=self.decode_shardings,
+        sharding_rule=self.sharding_rule,
         in_tree=in_tree,
         out_tree=out_tree(),
         static_args=static_args
@@ -483,6 +516,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        propagate_user_sharding, partition,
                                        infer_sharding_from_operands,
                                        decode_shardings,
+                                       sharding_rule,
                                        static_args):
   axis_context = ctx.module_context.axis_context
   if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
@@ -539,6 +573,9 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim):
       backend_config=ir.StringAttr.get(key),
       operand_layouts=None,
       result_layouts=None)
+  if sharding_rule is not None:
+    value_types = [mlir.aval_to_ir_type(s) for s in call.in_avals]
+    out.attributes['sdy.sharding_rule'] = sdy_sharding_rule_to_mlir(sharding_rule, value_types, result_types)
   return out.results

 mlir.register_lowering(custom_partitioning_p,
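
Taken together, these hunks let custom_partitioning users pass a Shardy sharding rule alongside the existing GSPMD callbacks; def_partition then keeps whichever side the use_shardy_partitioner flag selects. A minimal sketch of the new call pattern (the identity op and both placeholder callbacks are illustrative only):

    import jax
    from jax.experimental.custom_partitioning import custom_partitioning

    @custom_partitioning
    def noop(x):
        return x

    def partition(mesh, arg_shapes, result_shape):
        # Placeholder GSPMD partitioner: run the op per shard, keep input shardings.
        arg_shardings = jax.tree.map(lambda s: s.sharding, arg_shapes)
        return mesh, lambda x: x, result_shape.sharding, arg_shardings

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
        # Placeholder: the result takes the first operand's sharding.
        return arg_shapes[0].sharding

    # With jax_use_shardy_partitioner enabled, only sharding_rule is kept;
    # otherwise the two callbacks above are used and sharding_rule is dropped.
    noop.def_partition(
        partition=partition,
        infer_sharding_from_operands=infer_sharding_from_operands,
        sharding_rule='...i -> ...i')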

jax/_src/custom_partitioning_sharding_rule.py

Lines changed: 9 additions & 2 deletions
@@ -27,6 +27,9 @@
 # leading ... into factors.
 _BATCHING_DIM_FACTOR_PREFIX = "?"

+# A Jax value in general corresponds to an ir.Type or a tuple of ir.Types.
+IrTypes = ir.Type | tuple[ir.Type, ...]
+
 def _check_factor(factor:str):
   """Validates a factor.
@@ -278,8 +281,8 @@ def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule:

 def sdy_sharding_rule_to_mlir(
     rule: SdyShardingRule,
-    operand_types: list[ir.Type],
-    result_types: list[ir.Type],) -> ir.Attribute:
+    operand_types: list[IrTypes],
+    result_types: list[IrTypes],) -> ir.Attribute:
   """Builds the MLIR representation for the sharding rule.

   This is done by verifying that the rule is consistent with the types of
@@ -294,6 +297,10 @@
     raise ValueError(
         f"Sharding rule has {len(rule.result_mappings)} results, but the operation"
         f" has {len(result_types)} results")
+  if not all(isinstance(t, ir.Type) for t in operand_types + result_types):
+    raise TypeError(
+        f"operand_types and result_types must be a list of ir.Type, but got"
+        f" {operand_types} and {result_types}")

   factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
   types = operand_types + result_types
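
For reference, the Einsum-like strings accepted here parse into the same SdyShardingRule objects a caller could build by hand. A small sketch of the correspondence, importing from the private module path used in this diff (not a stable public API):

    from jax._src.custom_partitioning_sharding_rule import (
        SdyShardingRule, str_to_sdy_sharding_rule)

    # Space-separated factors; multi-letter factor names are allowed.
    rule = str_to_sdy_sharding_rule('batch d -> d batch')
    assert isinstance(rule, SdyShardingRule)

    # The same rule, spelled directly as an SdyShardingRule object.
    explicit = SdyShardingRule(operand_mappings=(('batch', 'd'),),
                               result_mappings=(('d', 'batch'),))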

jax/_src/interpreters/mlir.py

Lines changed: 27 additions & 16 deletions
@@ -49,7 +49,8 @@
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
-from jax._src.sharding_impls import AUTO, NamedSharding
+from jax._src.sharding_impls import (AUTO, NamedSharding,
+                                     modify_sdy_sharding_wrt_axis_types)
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -1689,13 +1690,17 @@ def lower_jaxpr_to_fun(
                     for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

   if ir_result_shardings is not None:
-    flat_outputs = [
-        wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s,
-                              unspecified_dims=us[2])
-        if us[0] and not us[1] else o
-        for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
-                                    output_avals, unconstrained_shardings)  # type: ignore
-    ]
+    temp_flat_outputs = []
+    for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
+                                output_avals, unconstrained_shardings):  # type: ignore
+      if us[0] and not us[1]:
+        if config.use_shardy_partitioner.value and config.sharding_in_types.value:
+          s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
+        temp_flat_outputs.append(wrap_with_sharding_op(
+            entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
+      else:
+        temp_flat_outputs.append(o)
+    flat_outputs = temp_flat_outputs

   # Insert a custom call if output is on host because XLA needs that to do the
   # transfer.
@@ -2594,14 +2599,20 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
     return op
   # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
   # `return op` early and avoid bloating HLO size.
-  proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
-           if sharding_proto is None else sharding_proto)
-  unspecified_dims = None
-  if aval.sharding.mesh._any_axis_collective:
-    unspecified_dims = set(range(aval.ndim))
-  elif aval.sharding.mesh._any_axis_auto:
-    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
-  return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
+  if config.use_shardy_partitioner.value:
+    proto = (aval.sharding._to_sdy_sharding(aval.ndim)
+             if sharding_proto is None else sharding_proto)
+    proto = modify_sdy_sharding_wrt_axis_types(proto, aval.sharding.mesh)
+    return wrap_with_sharding_op(ctx, op, aval, proto)
+  else:
+    proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+             if sharding_proto is None else sharding_proto)
+    unspecified_dims = None
+    if aval.sharding.mesh._any_axis_auto:
+      # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes
+      # as unspecified?
+      unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
+    return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


 def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):

jax/_src/interpreters/pxla.py

Lines changed: 3 additions & 6 deletions
@@ -2163,12 +2163,9 @@ def _abstract_to_concrete_mesh(abstract_mesh):
   out = []
   for s, a in zip(shardings, avals):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
-      if config.use_shardy_partitioner.value:
-        spec = a.sharding.spec
-      else:
-        spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
-                                for sp in a.sharding.spec])
-                if a.sharding.mesh._any_axis_auto else a.sharding.spec)
+      spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
+                              for sp in a.sharding.spec])
+              if a.sharding.mesh._any_axis_auto else a.sharding.spec)
       out.append(NamedSharding(
           _abstract_to_concrete_mesh(a.sharding.mesh), spec))
     else:
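
The surviving branch rewrites unconstrained (None) entries of a spec to PartitionSpec.UNCONSTRAINED whenever the mesh has an AUTO axis, now for Shardy and GSPMD alike. The rewrite in isolation (any_axis_auto stands in for a.sharding.mesh._any_axis_auto):

    from jax.sharding import PartitionSpec

    spec = PartitionSpec('x', None)
    any_axis_auto = True

    new_spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
                                for sp in spec])
                if any_axis_auto else spec)
    print(new_spec)  # PartitionSpec('x', UNCONSTRAINED)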
