Skip to content

Commit 6ee76a8

Browse files
Merge pull request #271 from ROCm/ci-upstream-sync-142_1
CI: 03/11/25 upstream sync
2 parents ce53e37 + fb89a4b commit 6ee76a8

File tree

147 files changed

+6370
-2093
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

147 files changed

+6370
-2093
lines changed

.bazelrc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ build:macos --apple_platform_type=macos
5454
build:macos --linkopt=-Wl,-undefined,dynamic_lookup
5555
build:macos --host_linkopt=-Wl,-undefined,dynamic_lookup
5656

57+
# Use cc toolchains from apple_support for Apple builds.
58+
# https://github.com/bazelbuild/apple_support/tree/master?tab=readme-ov-file#bazel-6-setup
59+
build:macos --apple_crosstool_top=@local_config_apple_cc//:toolchain
60+
build:macos --crosstool_top=@local_config_apple_cc//:toolchain
61+
build:macos --host_crosstool_top=@local_config_apple_cc//:toolchain
62+
5763
# Windows has a relatively short command line limit, which JAX has begun to hit.
5864
# See https://docs.bazel.build/versions/main/windows.html
5965
build:windows --features=compiler_param_file

.github/workflows/jax-array-api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
with:
2929
repository: data-apis/array-api-tests
3030
# TODO(jakevdp) update this to a stable release/tag when available.
31-
ref: 'd982a6245400295477f5da5afa1c4a2a5e641ea4' # Latest commit as of 2025-01-30
31+
ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04
3232
submodules: 'true'
3333
path: 'array-api-tests'
3434
- name: Set up Python ${{ matrix.python-version }}

.github/workflows/pytest_cpu.yml

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ on:
2929
type: string
3030
required: true
3131
default: "0"
32-
install-jax-current-commit:
33-
description: "Should the 'jax' package be installed from the current commit?"
34-
type: string
35-
required: true
36-
default: "1"
3732
gcs_download_uri:
3833
description: "GCS location prefix from where the artifacts should be downloaded"
3934
required: true
@@ -62,7 +57,6 @@ jobs:
6257
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
6358
JAXCI_PYTHON: "python${{ inputs.python }}"
6459
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
65-
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
6660

6761
steps:
6862
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -88,22 +82,18 @@ jobs:
8882
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
8983
# `*-cp<py_version>-cp<py_version>t-*`.
9084
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
91-
- name: Download jaxlib wheel from GCS (non-Windows runs)
85+
- name: Download wheels from GCS (non-Windows runs)
9286
id: download-wheel-artifacts-nw
9387
# Set continue-on-error to true to prevent actions from failing the workflow if this step
9488
# fails. Instead, we verify the outcome in the step below so that we can print a more
9589
# informative error message.
9690
continue-on-error: true
9791
if: ${{ !contains(inputs.runner, 'windows-x86') }}
9892
run: |
99-
mkdir -p $(pwd)/dist &&
93+
mkdir -p $(pwd)/dist
94+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
10095
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
101-
102-
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
103-
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
104-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
105-
fi
106-
- name: Download jaxlib wheel from GCS (Windows runs)
96+
- name: Download wheels from GCS (Windows runs)
10797
id: download-wheel-artifacts-w
10898
# Set continue-on-error to true to prevent actions from failing the workflow if this step
10999
# fails. Instead, we verify the outcome in step below so that we can print a more
@@ -115,12 +105,8 @@ jobs:
115105
mkdir dist
116106
@REM Use `call` so that we can run sequential gsutil commands on Windows
117107
@REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
108+
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
118109
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
119-
120-
@REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
121-
if not "${{ inputs.install-jax-current-commit }}"=="1" (
122-
call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
123-
)
124110
- name: Skip the test run if the wheel artifacts were not downloaded successfully
125111
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
126112
run: |

.github/workflows/pytest_cuda.yml

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,6 @@ on:
3434
type: string
3535
required: true
3636
default: "0"
37-
install-jax-current-commit:
38-
description: "Should the 'jax' package be installed from the current commit?"
39-
type: string
40-
required: true
41-
default: "1"
4237
gcs_download_uri:
4338
description: "GCS location prefix from where the artifacts should be downloaded"
4439
required: true
@@ -66,7 +61,6 @@ jobs:
6661
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
6762
JAXCI_PYTHON: "python${{ inputs.python }}"
6863
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
69-
JAXCI_INSTALL_JAX_CURRENT_COMMIT: "${{ inputs.install-jax-current-commit }}"
7064

7165
steps:
7266
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -86,22 +80,18 @@ jobs:
8680
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
8781
# `*-cp<py_version>-cp<py_version>t-*`.
8882
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
89-
- name: Download the wheel artifacts from GCS
83+
- name: Download wheels from GCS
9084
id: download-wheel-artifacts
9185
# Set continue-on-error to true to prevent actions from failing the workflow if this step
9286
# fails. Instead, we verify the outcome in the next step so that we can print a more
9387
# informative error message.
9488
continue-on-error: true
9589
run: |
9690
mkdir -p $(pwd)/dist &&
91+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ &&
9792
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
9893
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
9994
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
100-
101-
# Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
102-
if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
103-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
104-
fi
10595
- name: Skip the test run if the wheel artifacts were not downloaded successfully
10696
if: steps.download-wheel-artifacts.outcome == 'failure'
10797
run: |

.github/workflows/tsan-suppressions.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ race_top:PyMember_GetOne
2626
# https://github.com/python/cpython/issues/129547
2727
race:type_get_annotations
2828

29-
# https://github.com/python/cpython/issues/130547
30-
race:split_keys_entry_added
3129

3230
# https://github.com/python/cpython/issues/129748
3331
race:mi_block_set_nextx
@@ -64,3 +62,6 @@ race:gemm_oncopy
6462

6563
# https://github.com/python/cpython/issues/130571
6664
# race:_PyObject_GetMethod
65+
66+
# https://github.com/python/cpython/issues/130547
67+
# race:split_keys_entry_added

.github/workflows/tsan.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \
3636
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
3737
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
38-
libffi-dev liblzma-dev
38+
libffi-dev liblzma-dev file zip
3939
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
4040
with:
4141
path: jax

.github/workflows/wheel_tests_continuous.yml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ concurrency:
2727
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
2828

2929
jobs:
30+
build-jax-artifact:
31+
uses: ./.github/workflows/build_artifacts.yml
32+
with:
33+
# Note that since jax is a pure python package, the runner OS and Python values do not
34+
# matter. In addition, cloning main XLA also has no effect.
35+
runner: "linux-x86-n2-16"
36+
artifact: "jax"
37+
upload_artifacts_to_gcs: true
38+
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
39+
3040
build-jaxlib-artifact:
3141
uses: ./.github/workflows/build_artifacts.yml
3242
strategy:
@@ -66,7 +76,7 @@ jobs:
6676
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
6777
# still want to run the tests for other platforms.
6878
if: ${{ !cancelled() }}
69-
needs: build-jaxlib-artifact
79+
needs: [build-jax-artifact, build-jaxlib-artifact]
7080
uses: ./.github/workflows/pytest_cpu.yml
7181
strategy:
7282
fail-fast: false # don't cancel all jobs on failure
@@ -80,15 +90,14 @@ jobs:
8090
runner: ${{ matrix.runner }}
8191
python: ${{ matrix.python }}
8292
enable-x64: ${{ matrix.enable-x64 }}
83-
install-jax-current-commit: 1
8493
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
8594

8695
run-pytest-cuda:
8796
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
8897
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
8998
# still want to run the tests for other platforms.
9099
if: ${{ !cancelled() }}
91-
needs: [build-jaxlib-artifact, build-cuda-artifacts]
100+
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
92101
uses: ./.github/workflows/pytest_cuda.yml
93102
strategy:
94103
fail-fast: false # don't cancel all jobs on failure
@@ -111,7 +120,6 @@ jobs:
111120
python: ${{ matrix.python }}
112121
cuda: ${{ matrix.cuda }}
113122
enable-x64: ${{ matrix.enable-x64 }}
114-
install-jax-current-commit: 1
115123
# GCS upload URI is the same for both artifact build jobs
116124
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
117125

.github/workflows/wheel_tests_nightly_release.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ jobs:
4040
runner: ${{ matrix.runner }}
4141
python: ${{ matrix.python }}
4242
enable-x64: ${{ matrix.enable-x64 }}
43-
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
44-
# GCS bucket.
45-
install-jax-current-commit: 0
4643
gcs_download_uri: ${{inputs.gcs_download_uri}}
4744

4845
run-pytest-cuda:
@@ -61,7 +58,4 @@ jobs:
6158
python: ${{ matrix.python }}
6259
cuda: ${{ matrix.cuda }}
6360
enable-x64: ${{ matrix.enable-x64 }}
64-
# Don't install "jax" at head. Instead install the nightly/release "jax" wheels found in the
65-
# GCS bucket.
66-
install-jax-current-commit: 0
6761
gcs_download_uri: ${{inputs.gcs_download_uri}}

BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
15+
load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
1616
load(
1717
"//jaxlib:jax.bzl",
1818
"jax_wheel",

CHANGELOG.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2323
true, matching the current behavior. If set to false, JAX does not need to
2424
emit code clamping negative indices, which improves code size.
2525

26+
## jax 0.5.2 (Mar 4, 2025)
27+
28+
Patch release of 0.5.1
29+
30+
* Bug fixes
31+
* Fixes TPU metric logging and `tpu-info`, which was broken in 0.5.1
32+
2633
## jax 0.5.1 (Feb 24, 2025)
2734

2835
* New Features
@@ -54,6 +61,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
5461
A downstream effect of this several other internal functions need debug
5562
info. This change does not affect public APIs.
5663
See https://github.com/jax-ml/jax/issues/26480 for more detail.
64+
* In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`,
65+
non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
5766

5867
* Bug fixes
5968
* TPU runtime startup and shutdown time should be significantly improved on
@@ -169,8 +178,6 @@ to signify this.
169178

170179
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
171180

172-
## jax 0.4.37
173-
174181
* Bug fixes
175182
* Fixed a bug where `jit` would error if an argument was named `f` (#25329).
176183
* Fix a bug that will throw `index out of range` error in

0 commit comments

Comments
 (0)