Skip to content

Commit 7a172d7

Browse files
Merge pull request #285 from ROCm/ci-upstream-sync-147_1
CI: 03/14/25 upstream sync
2 parents 022da91 + 7a6940b commit 7a172d7

File tree

104 files changed

+3468
-1030
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

104 files changed

+3468
-1030
lines changed

.github/workflows/pytest_cpu.yml

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -116,21 +116,8 @@ jobs:
116116
exit 1
117117
- name: Install Python dependencies
118118
run: |
119-
# Remove installation of NVIDIA wheels for CPU tests.
120-
sed -i 's/-r gpu-test-requirements.txt/# -r gpu-test-requirements.txt/g' build/requirements.in
121-
122-
# TODO(srnitin): Remove after uv is installed in the Windows Dockerfile
123119
$JAXCI_PYTHON -m pip install uv~=0.5.30
124-
# python 3.13t cannot compile zstandard 0.23.0 due to
125-
# https://github.com/indygreg/python-zstandard/issues/231. Remove this once zstandard
126-
# has a prebuilt wheel for 3.13t or an env marker is available for free threading python
127-
# in requirements.in.
128-
if [[ $JAXCI_PYTHON =~ "python3.13-nogil" ]]; then
129-
grep -v "zstandard" build/requirements.in > build/requirements_without_zstandard.txt
130-
$JAXCI_PYTHON -m uv pip install -r build/requirements_without_zstandard.txt
131-
else
132-
$JAXCI_PYTHON -m uv pip install -r build/requirements.in
133-
fi
120+
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
134121
# Halt for testing
135122
- name: Wait For Connection
136123
uses: google-ml-infra/actions/ci_connection@main

.github/workflows/pytest_cuda.yml

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,7 @@ jobs:
100100
echo "Skipping the test run."
101101
exit 1
102102
- name: Install Python dependencies
103-
run: |
104-
# python 3.13t cannot compile zstandard 0.23.0 due to
105-
# https://github.com/indygreg/python-zstandard/issues/231. Remove this once zstandard
106-
# has a prebuilt wheel for 3.13t or an env marker is available for free threading python
107-
# in requirements.in.
108-
if [[ $JAXCI_PYTHON =~ "python3.13-nogil" ]]; then
109-
grep -v "zstandard" build/requirements.in > build/requirements_without_zstandard.txt
110-
$JAXCI_PYTHON -m uv pip install -r build/requirements_without_zstandard.txt
111-
else
112-
$JAXCI_PYTHON -m uv pip install -r build/requirements.in
113-
fi
103+
run: $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
114104
# Halt for testing
115105
- name: Wait For Connection
116106
uses: google-ml-infra/actions/ci_connection@main

.github/workflows/wheel_tests_continuous.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ concurrency:
2929
jobs:
3030
build-jax-artifact:
3131
uses: ./.github/workflows/build_artifacts.yml
32+
name: "Build jax artifact"
3233
with:
3334
# Note that since jax is a pure python package, the runner OS and Python values do not
3435
# matter. In addition, cloning main XLA also has no effect.
@@ -46,6 +47,10 @@ jobs:
4647
runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-64"]
4748
artifact: ["jaxlib"]
4849
python: ["3.10"]
50+
# Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
51+
# dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
52+
# values to the name and creates a separate entry for each matrix combination.
53+
name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
4954
with:
5055
runner: ${{ matrix.runner }}
5156
artifact: ${{ matrix.artifact }}
@@ -63,6 +68,7 @@ jobs:
6368
runner: ["linux-x86-n2-16"]
6469
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
6570
python: ["3.10",]
71+
name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
6672
with:
6773
runner: ${{ matrix.runner }}
6874
artifact: ${{ matrix.artifact }}
@@ -86,6 +92,7 @@ jobs:
8692
runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
8793
python: ["3.10",]
8894
enable-x64: [1, 0]
95+
name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
8996
with:
9097
runner: ${{ matrix.runner }}
9198
python: ${{ matrix.python }}
@@ -115,6 +122,7 @@ jobs:
115122
- runner: "linux-x86-a3-8g-h100-8gpu"
116123
python: "3.10"
117124
enable-x64: 0
125+
name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
118126
with:
119127
runner: ${{ matrix.runner }}
120128
python: ${{ matrix.python }}
@@ -137,6 +145,7 @@ jobs:
137145
runner: ["linux-x86-g2-48-l4-4gpu",]
138146
python: ["3.10",]
139147
enable-x64: [1, 0]
148+
name: "Bazel CUDA Non-RBE (JAX artifacts version = ${{ format('{0}', 'head') }})"
140149
with:
141150
runner: ${{ matrix.runner }}
142151
python: ${{ matrix.python }}
@@ -160,7 +169,7 @@ jobs:
160169
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
161170
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
162171
]
163-
name: "TPU tests (jax=head, jaxlib=head)"
172+
name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
164173
with:
165174
runner: ${{ matrix.tpu-specs.runner }}
166175
cores: ${{ matrix.tpu-specs.cores }}

.github/workflows/wheel_tests_nightly_release.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ jobs:
3636
exclude:
3737
- runner: "windows-x86-n2-64"
3838
python: "3.13-nogil"
39+
name: "Pytest CPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
3940
with:
4041
runner: ${{ matrix.runner }}
4142
python: ${{ matrix.python }}
@@ -53,6 +54,7 @@ jobs:
5354
python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"]
5455
cuda: ["12.3", "12.1"]
5556
enable-x64: [0]
57+
name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
5658
with:
5759
runner: ${{ matrix.runner }}
5860
python: ${{ matrix.python }}
@@ -88,7 +90,7 @@ jobs:
8890
- tpu-specs:
8991
type: "v5e-8"
9092
python: "3.11"
91-
name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
93+
name: "Pytest TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
9294
with:
9395
runner: ${{ matrix.tpu-specs.runner }}
9496
cores: ${{ matrix.tpu-specs.cores }}

BUILD.bazel

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ transitive_py_deps(
3131
"//jax:experimental",
3232
"//jax:experimental_colocated_python",
3333
"//jax:experimental_sparse",
34-
"//jax:internal_export_back_compat_test_util",
35-
"//jax:internal_test_harnesses",
36-
"//jax:internal_test_util",
3734
"//jax:lax_reference",
3835
"//jax:pallas_experimental_gpu_ops",
3936
"//jax:pallas_gpu_ops",

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Patch release of 0.5.1
4242
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
4343
* {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support
4444
column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
45+
* Added {func}`jax.random.multinomial`.
4546
{jax-issue}`#25955` for more details.
4647

4748
* Changes

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,4 +456,3 @@ For details about the JAX API, see the
456456

457457
For getting started as a JAX developer, see the
458458
[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).
459-

benchmarks/api_benchmark.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import google_benchmark
2222
import jax
2323
from jax import lax
24-
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
25-
from jax._src import core
26-
from jax._src.lib import xla_client as xc
2724
from jax._src import array
25+
from jax._src import core
2826
from jax._src import op_shardings
27+
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
28+
from jax._src.lib import xla_client as xc
2929
from jax._src.pjit import pjit_check_aval_sharding
30-
from jax.experimental import pjit as pjit_lib
3130
from jax.experimental import multihost_utils
31+
from jax.experimental import pjit as pjit_lib
3232
import jax.numpy as jnp
3333
import numpy as np
3434

@@ -860,29 +860,44 @@ def safe_zip(state):
860860

861861
@google_benchmark.register
862862
def bench_make_array_from_callback_fully_replicated_sharding(state):
863-
mesh = jax.sharding.Mesh(
864-
np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y'))
865-
shape = (8, 2)
866-
np_arr = np.arange(16).reshape(shape)
867-
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
863+
mesh = create_mesh((4, 2), ('x', 'y'), state)
864+
if mesh is None:
865+
return
866+
input_shape = (8, 2)
867+
np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)
868868

869+
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
869870
while state:
870-
jax.make_array_from_callback(shape, s, np_arr.__getitem__)
871+
jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
871872

872873

873874
@google_benchmark.register
874875
@google_benchmark.option.unit(google_benchmark.kMillisecond)
875-
def bench_make_array_from_callback_sharded(state):
876-
global_mesh = create_mesh((4, 2), ('x', 'y'), state)
876+
def bench_make_array_from_callback_partially_replicated_sharding(state):
877+
mesh = create_mesh((4, 2), ('x', 'y'), state)
878+
if mesh is None:
879+
return
877880
input_shape = (8, 2)
878-
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
881+
np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)
882+
883+
s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'y'))
884+
while state:
885+
jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
879886

880-
def callback(index):
881-
return input_data[index]
882887

883-
s = jax.NamedSharding(global_mesh, jax.sharding.PartitionSpec('x', 'y'))
888+
@google_benchmark.register
889+
@google_benchmark.option.unit(google_benchmark.kMillisecond)
890+
def bench_make_array_from_callback_fully_sharded_sharding(state):
891+
mesh = create_mesh((4, 2), ('x', 'y'), state)
892+
if mesh is None:
893+
return
894+
input_shape = (8, 2)
895+
np_arr = np.arange(math.prod(input_shape)).reshape(input_shape)
896+
897+
s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
884898
while state:
885-
jax.make_array_from_callback((8, 2), s, callback)
899+
jax.make_array_from_callback(input_shape, s, np_arr.__getitem__)
900+
886901

887902
@google_benchmark.register
888903
@google_benchmark.option.unit(google_benchmark.kMillisecond)

build/test-requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ flatbuffers
77
hypothesis
88
mpmath>=1.3
99
pillow>=10.4.0
10-
# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t
10+
# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t
1111
portpicker; python_version<"3.13"
1212
pytest-xdist
1313
wheel
@@ -19,3 +19,6 @@ matplotlib~=3.8.4; python_version=="3.10"
1919
matplotlib; python_version>="3.11"
2020
opt-einsum
2121
auditwheel
22+
23+
# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
24+
numpy~=2.1.0; platform_system == "Linux" and platform_machine == "aarch64"

docs/_static/style.css

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,13 @@ html[data-theme="light"] .highlight span.gt {
296296
html[data-theme="light"] .highlight span.gr {
297297
color: #ff0000;
298298
}
299+
300+
.header-article-items__start {
301+
display: flex;
302+
flex-direction: row;
303+
gap: 0.5em;
304+
}
305+
306+
.bd-breadcrumbs {
307+
margin-bottom: 0;
308+
}

0 commit comments

Comments
 (0)