Commit d0f6b95

Merge pull request #136 from ROCm/ci-upstream-sync-16_1
CI: 11/12/24 upstream sync
2 parents 0b970b8 + f3f6446 commit d0f6b95

78 files changed: 2086 additions and 398 deletions

.github/workflows/ci-build.yaml

Lines changed: 2 additions & 2 deletions
@@ -35,11 +35,11 @@ jobs:
         with:
           python-version: 3.11
       - run: python -m pip install pre-commit
-      - uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
+      - uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
         with:
           path: ~/.cache/pre-commit
           key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
-      - run: pre-commit run --show-diff-on-failure --color=always
+      - run: pre-commit run --show-diff-on-failure --color=always --all-files
 
   build:
     name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     for information on migrating to the new API.
   * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
     has been removed, after being deprecated in v0.4.27.
+  * Calling `np.asarray` on typed PRNG keys (i.e. keys produced by :func:`jax.random.key`)
+    now raises an error. Previously, this returned a scalar object array.
   * The following deprecated methods and functions in {mod}`jax.export` have
     been removed:
       * `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect

docs/about.md

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+(about-the-project)=
+
+# About the project
+
+The JAX project is led by the JAX core team. We develop in the open,
+and welcome open-source contributions from across the community. We
+frequently see contributions from [Google
+DeepMind](https://deepmind.google/), Alphabet more broadly,
+[NVIDIA](https://docs.nvidia.com/deeplearning/frameworks/jax-release-notes/overview.html),
+and elsewhere.
+
+At the heart of the project is the [JAX
+core](http://github.com/google/jax) library, which focuses on the
+fundamentals of machine learning and numerical computing, at scale.
+
+When [developing](#development) the core, we want to maintain agility
+and a focused scope, so we lean heavily on a surrounding [modular
+technology stack](#components). First, we design the `jax` module
+to be
+[composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations)
+and
+[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so
+that a wide variety of domain-specific libraries can thrive outside of
+it in a decentralized manner. Second, we lean heavily on a modular
+backend stack (compiler and runtime) to target different
+accelerators. Whether you are [writing a new domain-specific library
+built with JAX](#upstack), or looking to [support
+new hardware](#downstack), you can often
+contribute these with *minimal to no modifications* to the JAX core
+codebase.
+
+Many of JAX's core contributors have roots in open-source software and
+in research, in fields spanning computer science and the natural
+sciences. We strive to continuously enable the cutting edge of machine
+learning and numerical computing---across all compute platforms and
+accelerators---and to discover the truths of array programming at
+scale.
+
+(development)=
+## Open development
+
+JAX's day-to-day development takes place in the open on GitHub, using
+pull requests, the issue tracker, discussions, and [JAX Enhancement
+Proposals
+(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading
+and participating in these is a good way to get involved. We also
+maintain [developer
+notes](https://jax.readthedocs.io/en/latest/contributor_guide.html)
+that cover JAX's internal design.
+
+The JAX core team determines whether to accept changes and
+enhancements. Maintaining a simple decision-making structure currently
+helps us develop at the speed of the research frontier. Open
+development is a core value of ours, and we may adapt to a more
+intricate decision structure over time (e.g. with designated area
+owners) if/when it becomes useful to do so.
+
+For more see [contributing to
+JAX](https://jax.readthedocs.io/en/latest/contributing.html).
+
+(components)=
+## A modular stack
+
+To enable (a) a growing community of users across numerical domains,
+and (b) an advancing hardware landscape, we lean heavily on
+**modularity**.
+
+(upstack)=
+### Libraries built on JAX
+
+While the JAX core library focuses on the fundamentals, we want to
+encourage domain-specific libraries and tools to be built on top of
+JAX. Indeed, [many
+libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have
+emerged around JAX to offer higher-level features and extensions.
+
+How do we encourage such decentralized development? We guide it with
+several technical choices. First, JAX's main API focuses on basic
+building blocks (e.g. numerical primitives, NumPy operations, arrays,
+and transformations), encouraging auxiliary libraries to develop
+utilities as needed for their domain. In addition, JAX exposes a
+handful of more advanced APIs for
+[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)
+and
+[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries
+can [lean on these
+APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in
+order to use JAX as an internal means of implementation, to integrate
+more with its transformations like autodiff, and more.
+
+Projects across the JAX ecosystem are developed in a distributed and
+often open fashion. They are not governed by the JAX core team, even
+though sometimes team members contribute to them or maintain contact
+with their developers.
+
+(downstack)=
+### A pluggable backend
+
+We want JAX to run on CPUs, GPUs, TPUs, and other hardware platforms
+as they emerge. To encourage unhindered support of JAX on new
+platforms, the JAX core emphasizes modularity in its backend too.
+
+To manage hardware devices and memory, and for compilation to such
+devices, JAX calls out to the open [XLA
+compiler](https://openxla.org/) and the [PJRT
+runtime](https://github.com/openxla/xla/tree/main/xla/pjrt/c#pjrt---uniform-device-api). Both
+of these are projects external to the JAX core, governed and
+maintained by OpenXLA (again, with frequent contributions from and
+discussion with the JAX core developers).
+
+XLA aims for interoperability across accelerators (e.g. by ingesting
+[StableHLO](https://openxla.org/stablehlo) as input) and PJRT offers
+extensibility through a plug-in device API. Adding support for new
+devices is done by implementing a backend lowering for XLA, and
+implementing a plug-in device API defined by PJRT. If you're looking
+to contribute to compilation, or to supporting new hardware, we
+encourage you to contribute at the XLA and PJRT layers.
+
+These open system components allow third parties to support JAX on new
+accelerator platforms, *without requiring changes in the JAX
+core*. There are several plug-ins in development today. For example, a
+team at Apple is working on a PJRT plug-in to get [JAX running on
+Apple Metal](https://developer.apple.com/metal/jax/).

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@ maintains an up-to-date list.
    extensions
    notes
    jax
+   about
 
 
 .. toctree::

jax/_src/api.py

Lines changed: 2 additions & 6 deletions
@@ -2445,12 +2445,8 @@ def _device_get(x):
 
   # Extended dtypes dispatch via their device_get rule.
   if isinstance(x, basearray.Array) and dtypes.issubdtype(x.dtype, dtypes.extended):
-    try:
-      to_device = x.dtype._rules.device_get
-    except AttributeError:
-      pass
-    else:
-      return to_device(x)
+    bufs, tree = tree_util.dispatch_registry.flatten(x)
+    return tree.unflatten(device_get(bufs))
 
   # Other types dispatch via their __array__ method.
   try:
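
Review note: the new `_device_get` branch replaces a per-dtype `device_get` rule with a generic flatten, fetch, unflatten sequence. The following is only a rough, plain-Python sketch of that pattern. `KeyLike`, `KeyLikeDef`, `flatten`, `fetch`, and `device_get_sketch` are all hypothetical names invented for illustration; they are not JAX's `tree_util` API, and the real transfer logic is not shown in this diff.

```python
from dataclasses import dataclass


@dataclass
class KeyLike:
    """Hypothetical wrapper type holding one underlying buffer."""
    impl: str
    data: list


class KeyLikeDef:
    """Hypothetical tree definition: remembers how to rebuild a KeyLike."""

    def __init__(self, impl):
        self.impl = impl

    def unflatten(self, bufs):
        (data,) = bufs  # exactly one buffer is expected
        return KeyLike(self.impl, data)


def flatten(x):
    # Split the wrapper into its raw buffers plus rebuild metadata.
    return [x.data], KeyLikeDef(x.impl)


def fetch(bufs):
    # Stand-in for a device-to-host transfer: copy each buffer.
    return [list(b) for b in bufs]


def device_get_sketch(x):
    # Same shape as the two added lines: flatten, fetch, unflatten.
    bufs, tree = flatten(x)
    return tree.unflatten(fetch(bufs))
```

In this sketch, the wrapper type no longer needs its own fetch rule: anything that can be flattened into buffers is fetched the same way and rebuilt from its tree definition.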

jax/_src/array.py

Lines changed: 1 addition & 1 deletion
@@ -1173,7 +1173,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
     copy_outs = xc.batched_copy_array_to_devices_with_sharding(
         batch_xs, batch_devs, batch_shardings, batch_cs)
   else:
-    copy_outs = xc.batched_copy_array_to_devices_with_sharding( # type: ignore
+    copy_outs = xc.batched_copy_array_to_devices_with_sharding( # pytype: disable=missing-parameter
         batch_xs, batch_devs, batch_shardings)
   for i, copy_out in safe_zip(batch_indices, copy_outs):
     assert results[i] is None

jax/_src/compilation_cache.py

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,8 @@ def is_cache_used(backend: xla_client.Client) -> bool:
         _cache_used = True
     return _cache_used
 
+  return False
+
 
 def get_file_cache(path: str) -> tuple[CacheInterface, str] | None:
   """Returns the file cache and the path to the cache."""

jax/_src/config.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 import os
 import sys
 import threading
-from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast, TYPE_CHECKING
+from typing import Any, Generic, NamedTuple, NoReturn, Optional, Protocol, TypeVar, cast
 
 from jax._src import lib
 from jax._src.lib import guard_lib
@@ -371,7 +371,7 @@ class _Unset: pass
 
 _thread_local_state = threading.local()
 
-class State(Generic[_T]):
+class State(Generic[_T]): # type: ignore[no-redef]
 
   __slots__ = (
       '_name', '_value', '_update_thread_local_hook', '_update_global_hook',

jax/_src/core.py

Lines changed: 27 additions & 35 deletions
@@ -892,6 +892,11 @@ def unsafe_buffer_pointer(self):
 aval_property = namedtuple("aval_property", ["fget"])
 aval_method = namedtuple("aval_method", ["fun"])
 
+def check_eval_args(args):
+  for arg in args:
+    if isinstance(arg, Tracer):
+      raise escaped_tracer_error(arg)
+
 class EvalTrace(Trace):
 
   def process_primitive(self, primitive, args, params):
@@ -902,12 +907,11 @@ def process_primitive(self, primitive, args, params):
     else:
       # TODO(dougalm): delete. this shouldn't be necessary
       args = map(full_lower, args)
-      for arg in args:
-        if isinstance(arg, Tracer):
-          if config.data_dependent_tracing_fallback.value:
+      if config.data_dependent_tracing_fallback.value:
+        for arg in args:
+          if isinstance(arg, Tracer):
             return primitive.bind_with_trace(arg._trace, args, params)
-          else:
-            raise escaped_tracer_error(arg)
+      check_eval_args(args)
       return primitive.impl(*args, **params)
 
   def process_call(self, primitive, f, tracers, params):
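
Review note: the hunk above hoists the config check out of the per-argument loop and moves the escaped-tracer check into the new `check_eval_args` helper. The sketch below reproduces that control flow in plain Python so the ordering is easy to follow. `Tracer`, `EscapedTracerError`, `eval_primitive`, and the `fallback` callback are hypothetical stand-ins, not JAX's classes or signatures.

```python
class Tracer:
    """Hypothetical placeholder for a value that belongs to an active trace."""

    def __init__(self, name):
        self.name = name


class EscapedTracerError(Exception):
    """Hypothetical stand-in for the error built by escaped_tracer_error."""


def check_eval_args(args):
    # Same shape as the helper added in the diff: reject any leftover tracer.
    for arg in args:
        if isinstance(arg, Tracer):
            raise EscapedTracerError(f"tracer {arg.name!r} escaped its trace")


def eval_primitive(impl, args, fallback_enabled=False, fallback=None):
    args = list(args)  # materialize so the arguments can be scanned twice
    if fallback_enabled:
        # The config flag is now tested once, outside the loop.
        for arg in args:
            if isinstance(arg, Tracer):
                return fallback(arg, args)
    # With the fallback disabled (or no tracer found), validate and evaluate.
    check_eval_args(args)
    return impl(*args)
```

With the fallback disabled, a tracer argument raises. With it enabled, the first tracer found redirects the call, and `check_eval_args` only runs when no tracer was present.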
@@ -955,6 +959,7 @@ def __eq__(self, other):
 @dataclass(frozen=True)
 class AxisEnv:
   axis_sizes : dict[AxisName, int]
+  spmd_axis_names : set[AxisName]
 
   def axis_size(self, axis_name):
     if axis_name not in self.axis_sizes:
@@ -971,20 +976,24 @@ def axis_names(self):
   def pop_pure(self, axis_name):
     new_sizes = self.axis_sizes.copy()
     new_sizes.pop(axis_name)
-    return AxisEnv(new_sizes)
+    return AxisEnv(new_sizes, self.spmd_axis_names)
 
   def extend_pure(self, name_size_pairs):
     new_sizes = self.axis_sizes.copy()
     new_sizes.update((name, size) for name, size in name_size_pairs
                      if name is not no_axis_name)
-    return AxisEnv(new_sizes)
+    return AxisEnv(new_sizes, self.spmd_axis_names)
+
+  def add_spmd_axis_names(self, axis_names):
+    new_spmd_axis_names = self.spmd_axis_names | set(axis_names)
+    return AxisEnv(self.axis_sizes, new_spmd_axis_names)
 
   def as_hashable_key(self):
     return tuple((name, size) for (name, size) in self.axis_sizes.items()
                  if name is not no_axis_name)
 
 eval_trace = EvalTrace()
-top_axis_env = AxisEnv({})
+top_axis_env = AxisEnv({}, set())
 
 class TracingContext(threading.local):
   trace: Trace | None
@@ -1045,6 +1054,16 @@ def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]):
   finally:
     trace_ctx.set_axis_env(prev)
 
+@contextmanager
+def add_spmd_axis_names(axis_names: AxisName | None):
+  prev = trace_ctx.axis_env
+  try:
+    if axis_names is not None:
+      trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names))
+    yield
+  finally:
+    trace_ctx.set_axis_env(prev)
+
 def get_axis_env():
   return trace_ctx.axis_env
 
@@ -2092,33 +2111,6 @@ def get_bind_params(self, params):
 closed_call_p.def_effectful_abstract_eval(
     lambda *_, call_jaxpr: (call_jaxpr.out_avals, call_jaxpr.effects))
 
-
-outfeed_primitives: set[Primitive] = set()
-def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
-  """Finds if there are outfeed primitives anywhere inside a Jaxpr."""
-  return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
-             for eqn in jaxpr.eqns)
-
-def _param_uses_outfeed(param):
-  if type(param) is Jaxpr:
-    if jaxpr_uses_outfeed(param):
-      return True
-  elif type(param) is ClosedJaxpr:
-    if jaxpr_uses_outfeed(param.jaxpr):
-      return True
-  return False
-
-def primitive_uses_outfeed(prim: Primitive, params: dict) -> bool:
-  if prim in outfeed_primitives:
-    return True
-  for param in params.values():
-    if isinstance(param, tuple):
-      if any(unsafe_map(_param_uses_outfeed, param)):
-        return True
-    elif _param_uses_outfeed(param):
-      return True
-  return False
-
 # ------------------- Map -------------------
 
 class MapPrimitive(Primitive):
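
Review note: the `AxisEnv` changes in `jax/_src/core.py` thread a second field, `spmd_axis_names`, through every method that rebuilds the environment, and add a context manager that extends it temporarily. The sketch below is a simplified, self-contained approximation of that design. It assumes a module-level variable in place of the thread-local `trace_ctx` and a `frozenset` in place of `set`; `AxisEnvSketch` and `_current_env` are hypothetical names.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass(frozen=True)
class AxisEnvSketch:
    """Hypothetical, reduced version of the AxisEnv dataclass in the diff."""
    axis_sizes: dict
    spmd_axis_names: frozenset

    def extend_pure(self, name_size_pairs):
        # Returns a new environment; the SPMD names are carried over unchanged.
        new_sizes = dict(self.axis_sizes)
        new_sizes.update(name_size_pairs)
        return AxisEnvSketch(new_sizes, self.spmd_axis_names)

    def add_spmd_axis_names(self, axis_names):
        # Returns a new environment with the extra SPMD names unioned in.
        merged = self.spmd_axis_names | frozenset(axis_names)
        return AxisEnvSketch(self.axis_sizes, merged)


# Assumption: a plain global stands in for the thread-local trace context.
_current_env = AxisEnvSketch({}, frozenset())


def get_axis_env():
    return _current_env


@contextmanager
def add_spmd_axis_names(axis_names):
    global _current_env
    prev = _current_env
    try:
        if axis_names is not None:
            _current_env = prev.add_spmd_axis_names(axis_names)
        yield
    finally:
        # Always restore the previous environment, even on error.
        _current_env = prev
```

A caller would write `with add_spmd_axis_names(("data",)): ...` and read `get_axis_env().spmd_axis_names` inside the block. Passing a tuple rather than a bare string matters in this sketch, because building a set from a string would split it into characters.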

jax/_src/dispatch.py

Lines changed: 1 addition & 12 deletions
@@ -16,7 +16,7 @@
 from __future__ import annotations
 
 import atexit
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence
 import contextlib
 import dataclasses
 import enum
@@ -278,17 +278,6 @@ def _is_bint_axis_size(d: core.AxisSize) -> bool:
   return False
 
 
-# We can optionally set a Jaxpr rewriter that can be applied just before
-# compilation. This mechanism is used for compiling id_tap, we can
-# remove it once we bring the id_tap implementation into the core.
-outfeed_rewriter: Callable[[core.Jaxpr], core.Jaxpr] | None = None
-def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
-  if outfeed_rewriter is not None:
-    return outfeed_rewriter(jaxpr)
-  else:
-    return jaxpr
-
-
 def check_arg(arg: Any):
   if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)):
     raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
