Skip to content

Commit 174f0b8

Browse files
authored
Merge pull request #142 from ROCm/ci-upstream-sync-25_1
CI: 11/14/24 upstream sync
2 parents d0f6b95 + 34d9633 commit 174f0b8

File tree

92 files changed

+2193
-818
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+2193
-818
lines changed

.github/workflows/asan.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212
branches:
1313
- main
1414
paths:
15-
- '**/workflows/asan.yml'
15+
- '**/workflows/asan.yaml'
1616

1717
jobs:
1818
asan:
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: CI - Bazel CPU tests (RBE)
2+
3+
on:
4+
workflow_dispatch:
5+
inputs:
6+
halt-for-connection:
7+
description: 'Should this workflow run wait for a remote connection?'
8+
type: choice
9+
required: true
10+
default: 'no'
11+
options:
12+
- 'yes'
13+
- 'no'
14+
15+
concurrency:
16+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
17+
cancel-in-progress: true
18+
19+
jobs:
20+
run_tests:
21+
if: github.event.repository.fork == false
22+
strategy:
23+
matrix:
24+
runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"]
25+
26+
runs-on: ${{ matrix.runner }}
27+
# TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available
28+
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
29+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }}
30+
31+
env:
32+
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
33+
34+
steps:
35+
- uses: actions/checkout@v3
36+
- name: Wait For Connection
37+
uses: google-ml-infra/actions/ci_connection@main
38+
with:
39+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
40+
- name: Run Bazel CPU Tests with RBE
41+
run: ./ci/run_bazel_test_cpu_rbe.sh

.github/workflows/rocm-nightly-upstream-sync.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Pulls the latest changes from upstream into main and opens a PR to merge
2-
# them into rocm-main.
2+
# them into rocm-main branch.
33

44
name: ROCm Nightly Upstream Sync
55
on:

CHANGELOG.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4040
`platforms` instead.
4141
* Hashing of tracers, which has been deprecated since version 0.4.30, now
4242
results in a `TypeError`.
43+
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
44+
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
45+
on the function inputs.
46+
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
4347

4448
* New Features
4549
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
@@ -49,6 +53,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4953
declared inline via {func}`dataclasses.field`. See the function documentation
5054
for examples.
5155

56+
* Bug fixes
57+
* Fixed a bug where the GPU implementations of LU and QR decomposition would
58+
result in an indexing overflow for batch sizes close to int32 max. See
59+
{jax-issue}`#24843` for more details.
60+
5261
## jax 0.4.35 (Oct 22, 2024)
5362

5463
* Breaking Changes
@@ -79,7 +88,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
7988
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
8089
been deprecated. Use the JAX FFI instead.
8190
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
82-
`jax.lib.xla_client.ops`,
91+
`jax.lib.xla_client.ops`,
8392
`jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
8493
`jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
8594
`jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO

ci/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# JAX continuous integration
2+
3+
> [!WARNING]
4+
> This folder is still under construction. It is part of an ongoing
5+
> effort to improve the structure of CI and build related files within the
6+
> JAX repo. This warning will be removed when the contents of this
7+
> directory are stable and appropriate documentation around its usage is in
8+
> place.
9+
10+
********************************************************************************

ci/envs/default.env

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# This file contains all the default values for the "JAXCI_" environment
16+
# variables used in the CI scripts. These variables are used to control the
17+
# behavior of the CI scripts such as the Python version used, path to JAX/XLA
18+
# repo, if to clone XLA repo, etc.
19+
20+
# The path to the JAX git repository.
21+
export JAXCI_JAX_GIT_DIR=$(pwd)
22+
23+
# Controls the version of Hermetic Python to use. Use system default if not
24+
# set.
25+
export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')}
26+
27+
# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local
28+
# copy of XLA instead of the pinned version in the WORKSPACE. When
29+
# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically.
30+
export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-}
31+
32+
# If set to 1, the builds will clone the XLA repository at HEAD and set its
33+
# path in JAXCI_XLA_GIT_DIR.
34+
export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0}
35+
36+
# Allows overriding the XLA commit that is used.
37+
export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-}

ci/run_bazel_test_cpu_rbe.sh

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#!/bin/bash
2+
# Copyright 2024 The JAX Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
# Runs Bazel CPU tests with RBE.
17+
#
18+
# -e: abort script if one command fails
19+
# -u: error if undefined variable used
20+
# -x: log all commands
21+
# -o history: record shell history
22+
# -o allexport: export all functions and variables to be available to subscripts
23+
set -exu -o history -o allexport
24+
25+
# Source default JAXCI environment variables.
26+
source ci/envs/default.env
27+
28+
# Clone XLA at HEAD if path to local XLA is not provided
29+
if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then
30+
export JAXCI_CLONE_MAIN_XLA=1
31+
fi
32+
33+
# Set up the build environment.
34+
source "ci/utilities/setup_build_environment.sh"
35+
36+
# Run Bazel CPU tests with RBE.
37+
os=$(uname -s | awk '{print tolower($0)}')
38+
arch=$(uname -m)
39+
40+
# When running on Mac or Linux Aarch64, we only build the test targets and
41+
# not run them. These platforms do not have native RBE support so we
42+
# RBE cross-compile them on remote Linux x86 machines. As the tests still
43+
# need to be run on the host machine and because running the tests on a
44+
# single machine can take a long time, we skip running them on these
45+
# platforms.
46+
if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then
47+
echo "Building RBE CPU tests..."
48+
bazel build --config=rbe_cross_compile_${os}_${arch} \
49+
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
50+
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
51+
--test_env=JAX_NUM_GENERATED_CASES=25 \
52+
--test_env=JAX_SKIP_SLOW_TESTS=true \
53+
--action_env=JAX_ENABLE_X64=0 \
54+
--test_output=errors \
55+
--color=yes \
56+
//tests:cpu_tests //tests:backend_independent_tests
57+
else
58+
echo "Running RBE CPU tests..."
59+
bazel test --config=rbe_${os}_${arch} \
60+
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
61+
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
62+
--test_env=JAX_NUM_GENERATED_CASES=25 \
63+
--test_env=JAX_SKIP_SLOW_TESTS=true \
64+
--action_env=JAX_ENABLE_X64=0 \
65+
--test_output=errors \
66+
--color=yes \
67+
//tests:cpu_tests //tests:backend_independent_tests
68+
fi
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/bin/bash
2+
# Copyright 2024 The JAX Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
# Set up the build environment for JAX CI jobs. This script depends on the
17+
# "JAXCI_" environment variables set or sourced in the build script.
18+
19+
# Pre-emptively mark the JAX git directory as safe. This is necessary for JAX CI
20+
# jobs running on Linux runners in GitHub Actions. Without this, git complains
21+
# that the directory has dubious ownership and refuses to run any commands.
22+
# Avoid running on Windows runners as git runs into issues with not being able
23+
# to lock the config file. Other git commands seem to work on the Windows
24+
# runners so we can skip this step for Windows.
25+
# TODO(b/375073267): Remove this once we understand why git repositories are
26+
# being marked as unsafe inside the self-hosted runners.
27+
if [[ ! $(uname -s) =~ "MSYS_NT" ]]; then
28+
git config --global --add safe.directory $JAXCI_JAX_GIT_DIR
29+
fi
30+
31+
function clone_main_xla() {
32+
echo "Cloning XLA at HEAD to $(pwd)/xla"
33+
git clone --depth=1 https://github.com/openxla/xla.git $(pwd)/xla
34+
export JAXCI_XLA_GIT_DIR=$(pwd)/xla
35+
}
36+
37+
# Clone XLA at HEAD if required.
38+
if [[ "$JAXCI_CLONE_MAIN_XLA" == 1 ]]; then
39+
# Clone only if $(pwd)/xla does not exist to avoid failure on re-runs.
40+
if [[ ! -d $(pwd)/xla ]]; then
41+
clone_main_xla
42+
else
43+
echo "JAXCI_CLONE_MAIN_XLA set but local XLA folder already exists: $(pwd)/xla so using that instead."
44+
# Set JAXCI_XLA_GIT_DIR if local XLA already exists
45+
export JAXCI_XLA_GIT_DIR=$(pwd)/xla
46+
fi
47+
fi
48+
49+
# If a XLA commit is provided, check out XLA at that commit.
50+
if [[ ! -z "$JAXCI_XLA_COMMIT" ]]; then
51+
# Clone XLA at HEAD if a path to local XLA is not provided.
52+
if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then
53+
clone_main_xla
54+
fi
55+
pushd "$JAXCI_XLA_GIT_DIR"
56+
57+
git fetch --depth=1 origin "$JAXCI_XLA_COMMIT"
58+
echo "JAXCI_XLA_COMMIT is set. Checking out XLA at $JAXCI_XLA_COMMIT"
59+
git checkout "$JAXCI_XLA_COMMIT"
60+
61+
popd
62+
fi
63+
64+
if [[ ! -z ${JAXCI_XLA_GIT_DIR} ]]; then
65+
echo "INFO: Overriding XLA to be read from $JAXCI_XLA_GIT_DIR instead of the"
66+
echo "pinned version in the WORKSPACE."
67+
echo "If you would like to revert this behavior, unset JAXCI_CLONE_MAIN_XLA"
68+
echo "and JAXCI_XLA_COMMIT in your environment. Note that the Bazel RBE test"
69+
echo "commands overrides the XLA repository and thus require a local copy of"
70+
echo "XLA to run."
71+
fi

docs/Custom_Operation_for_GPUs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ class RmsNormFwdClass:
679679
NamedSharding(mesh, PartitionSpec(None, None)))
680680
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
681681
output_shardings = (arg_shardings[0], invvar_sharding)
682-
# Sharded_impl only accepts positional arugments
682+
# Sharded_impl only accepts positional arguments
683683
# And they should be Jax traceable variables
684684
impl = partial(RmsNormFwdClass.impl, eps=eps)
685685
@@ -739,7 +739,7 @@ class RmsNormBwdClass:
739739
output_shardings = (output_sharding, invvar_sharding, invvar_sharding)
740740
741741
742-
# Sharded_impl only accepts positional arugments
742+
# Sharded_impl only accepts positional arguments
743743
# And they should be Jax traceable variables
744744
def impl(g, invvar, x, weight):
745745
grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(

docs/Custom_Operation_for_GPUs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def partition(eps: float, mesh : jax.sharding.Mesh,
353353
NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything.
354354
invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
355355
output_shardings = (arg_shardings[0], invvar_sharding)
356-
# Sharded_impl only accepts positional arugments
356+
# Sharded_impl only accepts positional arguments
357357
# And they should be Jax traceable variables
358358
impl = partial(RmsNormFwdClass.impl, eps=eps)
359359

0 commit comments

Comments
 (0)