Commit 3be7c1e

Merge pull request #148 from ROCm/ci-upstream-sync-34_1
CI: 11/22/24 upstream sync
2 parents: 8607cb6 + 846697f

File tree: 160 files changed, +6010 −2401 lines (only a subset of the changed files is shown below).
Lines changed: 39 additions & 0 deletions (new GitHub Actions workflow; file path not shown in this view)

```diff
@@ -0,0 +1,39 @@
+name: CI - Bazel GPU tests (RBE)
+
+on:
+  workflow_dispatch:
+    inputs:
+      halt-for-connection:
+        description: 'Should this workflow run wait for a remote connection?'
+        type: choice
+        required: true
+        default: 'no'
+        options:
+        - 'yes'
+        - 'no'
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  run_tests:
+    if: github.event.repository.fork == false
+    strategy:
+      matrix:
+        runner: ["linux-x86-n2-16"]
+
+    runs-on: ${{ matrix.runner }}
+    container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'
+
+    env:
+      JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
+
+    steps:
+      - uses: actions/checkout@v3
+      - name: Wait For Connection
+        uses: google-ml-infra/actions/ci_connection@main
+        with:
+          halt-dispatch-input: ${{ inputs.halt-for-connection }}
+      - name: Run Bazel GPU Tests with RBE
+        run: ./ci/run_bazel_test_gpu_rbe.sh
```
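This workflow is manual-only (`workflow_dispatch`). Assuming the standard GitHub CLI, a run could be started with something like `gh workflow run "CI - Bazel GPU tests (RBE)" -f halt-for-connection=no` against whichever repository carries the workflow; the exact invocation is an illustration, not part of this commit.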

.github/workflows/ci-build.yaml

Lines changed: 1 addition & 1 deletion

```diff
@@ -139,7 +139,7 @@ jobs:
   documentation_render:
     name: Documentation - render documentation
     runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 20
     strategy:
       matrix:
         python-version: ['3.10']
```

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 26 additions & 23 deletions

```diff
@@ -13,7 +13,7 @@
 name: CI - Cloud TPU (nightly)
 on:
   schedule:
-    - cron: "0 14 * * *" # daily at 7am PST
+    - cron: "0 */2 * * *" # Run every 2 hours
   workflow_dispatch: # allows triggering the workflow run manually
 # This should also be set to read-only in the project settings, but it's nice to
 # document and enforce the permissions here.
@@ -26,15 +26,18 @@ jobs:
       matrix:
         jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
         tpu: [
-          {type: "v3-8", cores: "4"},
-          {type: "v4-8", cores: "4"},
-          {type: "v5e-8", cores: "8"}
+          # {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
+          # {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
+          {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
         ]
+        python-version: ["3.10"]
     name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
     env:
       LIBTPU_OLDEST_VERSION_DATE: 20240722
       ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
-    runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
+      PYTHON: python${{ matrix.python-version }}
+    runs-on: ${{ matrix.tpu.runner }}
+    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
     timeout-minutes: 120
     defaults:
       run:
@@ -46,52 +49,52 @@ jobs:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
       - name: Install JAX test requirements
         run: |
-          pip install -U -r build/test-requirements.txt
-          pip install -U -r build/collect-profile-requirements.txt
+          $PYTHON -m pip install -U -r build/test-requirements.txt
+          $PYTHON -m pip install -U -r build/collect-profile-requirements.txt
       - name: Install JAX
         run: |
-          pip uninstall -y jax jaxlib libtpu
+          $PYTHON -m pip uninstall -y jax jaxlib libtpu
           if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
-            pip install .[tpu] \
+            $PYTHON -m pip install .[tpu] \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

           elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
-            pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
-            pip install --pre libtpu \
+            $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+            $PYTHON -m pip install --pre libtpu \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-            pip install requests
+            $PYTHON -m pip install requests

           elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
-            pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+            $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
            # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
-            pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
+            $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-            pip install requests
+            $PYTHON -m pip install requests
           else
             echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
             exit 1
           fi

-          python3 -c 'import sys; print("python version:", sys.version)'
-          python3 -c 'import jax; print("jax version:", jax.__version__)'
-          python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
-          strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
-          python3 -c 'import jax; print("libtpu version:",
+          $PYTHON -c 'import sys; print("python version:", sys.version)'
+          $PYTHON -c 'import jax; print("jax version:", jax.__version__)'
+          $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
+          strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
+          $PYTHON -c 'import jax; print("libtpu version:",
            jax.lib.xla_bridge.get_backend().platform_version)'
       - name: Run tests
         env:
           JAX_PLATFORMS: tpu,cpu
           PY_COLORS: 1
         run: |
           # Run single-accelerator tests in parallel
-          JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
+          JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
            --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
            --maxfail=20 -m "not multiaccelerator" tests examples
           # Run Pallas printing tests, which need to run with I/O capturing disabled.
-          TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
+          TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
            tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
           # Run multi-accelerator across all chips
-          python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
+          $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
       - name: Send chat on failure
         # Don't notify when testing the workflow from a branch.
         if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
```
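With the new schedule, the cron expression `0 */2 * * *` fires at minute 0 of every second hour (UTC), so this job now runs twelve times a day instead of once at 14:00 UTC.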

.github/workflows/jax-array-api.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -28,7 +28,7 @@ jobs:
       with:
         repository: data-apis/array-api-tests
         # TODO(jakevdp) update this to a stable release/tag when available.
-        ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31
+        ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20
         submodules: 'true'
         path: 'array-api-tests'
     - name: Set up Python ${{ matrix.python-version }}
```

CHANGELOG.md

Lines changed: 28 additions & 0 deletions

```diff
@@ -13,6 +13,21 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 ## jax 0.4.36
 
 * Breaking Changes
+  * This release lands "stackless", an internal change to JAX's tracing
+    machinery. We made trace dispatch purely a function of context rather than a
+    function of both context and data. This let us delete a lot of machinery for
+    managing data-dependent tracing: levels, sublevels, `post_process_call`,
+    `new_base_main`, `custom_bind`, and so on. The change should only affect
+    users that use JAX internals.
+
+    If you do use JAX internals then you may need to
+    update your code (see
+    https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
+    for clues about how to do this). There might also be version skew
+    issues with JAX libraries that do this. If you find this change breaks your
+    non-JAX-internals-using code then try the
+    `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
+    you need help updating your code then please file a bug.
   * {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
     or with `enable_xla=False` have been deprecated since July 2024, with
     JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
@@ -43,6 +58,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
     inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
     on the function inputs.
+  * {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
+    return NaN for negative integer inputs, to match the behavior of SciPy from
+    https://github.com/scipy/scipy/pull/21827.
   * `jax.clear_backends` was removed after being deprecated in v0.4.26.
 
 * New Features
@@ -52,12 +70,22 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
     declared inline via {func}`dataclasses.field`. See the function documentation
     for examples.
+  * Added {func}`jax.numpy.put_along_axis`.
+  * {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
+    ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
+    supported on GPU. See {jax-issue}`#24663` for more details.
 
 * Bug fixes
   * Fixed a bug where the GPU implementations of LU and QR decomposition would
     result in an indexing overflow for batch sizes close to int32 max. See
     {jax-issue}`#24843` for more details.
 
+* Deprecations
+  * `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
+    use `jax.Array` instead.
+  * `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
+    instead.
+
 ## jax 0.4.35 (Oct 22, 2024)
 
 * Breaking Changes
```
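As a quick illustration of two of the new entries above, the following sketch is not from this commit; it assumes `jax.numpy.put_along_axis` mirrors NumPy's signature (with `inplace=False`, since JAX arrays are immutable) and shows the new NaN behavior of `jax.scipy.special.gamma` at negative integers.

```python
# Hedged example of two 0.4.36 changelog items; not part of this commit.
import jax.numpy as jnp
from jax.scipy.special import gamma

# New: jax.numpy.put_along_axis. JAX arrays are immutable, so the update is
# out-of-place (inplace=False) and the result comes back as a new array.
x = jnp.zeros((3, 4))
idx = jnp.array([[1], [2], [0]])
y = jnp.put_along_axis(x, idx, 9.0, axis=1, inplace=False)
print(y)  # one 9.0 per row, at columns 1, 2, 0

# Behavior change: gamma (and gammasgn) now return NaN at negative integers,
# matching SciPy.
print(gamma(jnp.array([-2.0, -1.0, 0.5])))  # [nan, nan, ~1.7725]
```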

README.md

Lines changed: 5 additions & 4 deletions

```diff
@@ -189,8 +189,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like.
 
 Using `jit` puts constraints on the kind of Python control flow
 the function can use; see
-the [Gotchas
-Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
+the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html)
 for more.
 
 ### Auto-vectorization with `vmap`
@@ -349,7 +348,7 @@ Some standouts:
 1. [In-place mutating updates of
    arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
 1. [Random numbers are
-   different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
+   different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
 1. If you're looking for [convolution
    operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
    they're in the `jax.lax` package.
@@ -369,7 +368,7 @@ Some standouts:
    and NumPy types aren't preserved, namely `np.add(1, np.array([2],
    np.float32)).dtype` is `float64` rather than `float32`.
 1. Some transformations, like `jit`, [constrain how you can use Python control
-   flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
+   flow](https://jax.readthedocs.io/en/latest/control-flow.html).
    You'll always get loud errors if something goes wrong. You might have to use
    [`jit`'s `static_argnums`
    parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
@@ -390,6 +389,7 @@ Some standouts:
 | Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
 | AMD GPU | yes | no | experimental | n/a | no | no |
 | Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
+| Intel GPU | experimental | n/a | n/a | n/a | no | no |
 
 
 ### Instructions
@@ -401,6 +401,7 @@ Some standouts:
 | Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
 | AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
 | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
+| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |
 
 See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
 for information on alternative installation strategies. These include compiling
```
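The two gotchas referenced by the updated links lend themselves to a short illustration. The snippet below is a hedged sketch using the stable public APIs (`Array.at[]` and `jax.random`), not code taken from this commit.

```python
import jax
import jax.numpy as jnp

# In-place updates: instead of `x[0] += 1.0`, use the functional .at[] API;
# it returns a new array, and under jit the buffer is reused in place.
x = jnp.zeros(3)
x = x.at[0].add(1.0)

# Random numbers: JAX uses explicit PRNG keys instead of global state.
key = jax.random.key(0)
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (2,))
print(x, sample)
```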

benchmarks/shape_poly_benchmark.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -17,7 +17,6 @@
 
 import jax
 from jax import core
-from jax._src.numpy import lax_numpy
 from jax import export
 
 jax.config.parse_flags_with_absl()
@@ -76,7 +75,7 @@ def inequalities_slice(state):
   while state:
     for _ in range(30):
       a.scope._clear_caches()
-      start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
+      start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
       _ = 0 <= slice_size <= b
       _ = start >= 0
       _ = start + slice_size <= b
```

ci/run_bazel_test_gpu_rbe.sh

Lines changed: 51 additions & 0 deletions

```diff
@@ -0,0 +1,51 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Runs Bazel GPU tests with RBE. This runs single accelerator tests with one
+# GPU apiece on RBE.
+#
+# -e: abort script if one command fails
+# -u: error if undefined variable used
+# -x: log all commands
+# -o history: record shell history
+# -o allexport: export all functions and variables to be available to subscripts
+set -exu -o history -o allexport
+
+# Source default JAXCI environment variables.
+source ci/envs/default.env
+
+# Clone XLA at HEAD if path to local XLA is not provided
+if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then
+    export JAXCI_CLONE_MAIN_XLA=1
+fi
+
+# Set up the build environment.
+source "ci/utilities/setup_build_environment.sh"
+
+# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece).
+echo "Running RBE GPU tests..."
+
+bazel test --config=rbe_linux_x86_64_cuda \
+      --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
+      --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
+      --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
+      --test_output=errors \
+      --test_env=TF_CPP_MIN_LOG_LEVEL=0 \
+      --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
+      --test_tag_filters=-multiaccelerator \
+      --test_env=JAX_SKIP_SLOW_TESTS=true \
+      --action_env=JAX_ENABLE_X64=0 \
+      --color=yes \
+      //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
```
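If run outside CI, the script depends on the JAXCI variables sourced from `ci/envs/default.env`; pointing `JAXCI_XLA_GIT_DIR` at a local XLA checkout (for example `JAXCI_XLA_GIT_DIR=$HOME/xla ./ci/run_bazel_test_gpu_rbe.sh`, a hypothetical path) skips cloning XLA at HEAD.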

docs/Custom_Operation_for_GPUs.md

Lines changed: 3 additions & 3 deletions

```diff
@@ -623,16 +623,16 @@ be used with the custom_partitioning registration and for the
 gradient. (And if you implement the interface to support vmat, it will
 also be on the outer primitive).
 
-JAX custom_partitioning implementation are callbacks from XLA to Python during XLA sharding logic.
+JAX custom_partitioning implementations are callbacks from XLA to Python during XLA sharding logic.
 XLA sharding goes in two phases: a sharding propagation phase and a partition phase.
-The propagation phase is when XLA plan the sharding to be created. It is the partition phase that create the sharded graph.
+The propagation phase is when XLA plan the sharding to be created. It is the partition phase that creates the sharded graph.
 For XLA to be able to shard our custom operations, it needs us to define 2 extra functions:
 infer_sharding_from_operands() and partition(). They are used in the first and second phase respectively.
 
 The infer_sharding_from_operands() function must do what its name say: infer the output sharding from the input sharding.
 
 The partition() function will do a few things:
-- tell which input sharding will be expected. XLA will reshad if needed.
+- tell which input sharding will be expected. XLA will reshard if needed.
 - tell the final version of the output sharding.
 - give a function that will create the new instruction from the sharded inputs.
```
