Skip to content

Commit c65ce4b

Browse files
authored
Merge branch 'main' into add-optimization-effort-flags
2 parents 83b54d9 + df6758f commit c65ce4b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1097
-323
lines changed

.github/workflows/asan.yaml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,8 @@ jobs:
2525
run:
2626
shell: bash -l {0}
2727
steps:
28-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
29-
with:
30-
path: jax
31-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
32-
with:
33-
repository: python/cpython
34-
path: cpython
35-
ref: v3.13.0
28+
# Install git before actions/checkout as otherwise it will download the code with the GitHub
29+
# REST API and therefore any subsequent git commands will fail.
3630
- name: Install clang 18
3731
env:
3832
DEBIAN_FRONTEND: noninteractive
@@ -42,6 +36,14 @@ jobs:
4236
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
4337
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
4438
libffi-dev liblzma-dev
39+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
40+
with:
41+
path: jax
42+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
43+
with:
44+
repository: python/cpython
45+
path: cpython
46+
ref: v3.13.0
4547
- name: Build CPython with ASAN enabled
4648
env:
4749
ASAN_OPTIONS: detect_leaks=0

.github/workflows/ci-build.yaml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,19 @@ jobs:
144144
145145
documentation_render:
146146
name: Documentation - render documentation
147-
runs-on: ubuntu-latest
147+
runs-on: linux-x86-n2-16
148+
container:
149+
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
148150
timeout-minutes: 10
149151
strategy:
150152
matrix:
151153
python-version: ['3.10']
152154
steps:
153155
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
156+
- name: Image Setup
157+
run: |
158+
apt update
159+
apt install -y libssl-dev libsqlite3-dev
154160
- name: Set up Python ${{ matrix.python-version }}
155161
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
156162
with:
@@ -170,8 +176,7 @@ jobs:
170176
pip install -r docs/requirements.txt
171177
- name: Render documentation
172178
run: |
173-
sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html
174-
179+
sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html
175180
176181
jax2tf_test:
177182
name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"

build/build.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,11 @@ async def main():
399399
else:
400400
requirements_command.append("//build:requirements.update")
401401

402-
await executor.run(requirements_command.get_command_as_string(), args.dry_run)
403-
sys.exit(0)
402+
result = await executor.run(requirements_command.get_command_as_string(), args.dry_run)
403+
if result.return_code != 0:
404+
raise RuntimeError(f"Command failed with return code {result.return_code}")
405+
else:
406+
sys.exit(0)
404407

405408
wheel_cpus = {
406409
"darwin_arm64": "arm64",
@@ -594,7 +597,13 @@ async def main():
594597

595598
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
596599

597-
await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
600+
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
601+
# Exit with error if any wheel build fails.
602+
if result.return_code != 0:
603+
raise RuntimeError(f"Command failed with return code {result.return_code}")
604+
605+
# Exit with success if all wheels in the list were built successfully.
606+
sys.exit(0)
598607

599608

600609
if __name__ == "__main__":

build/rocm/dev_build_rocm.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@ def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path):
7777
build_command = [
7878
"python3",
7979
"./build/build.py",
80-
"--enable_rocm",
81-
"--build_gpu_plugin",
82-
"--gpu_plugin_rocm_version=60",
80+
"build"
8381
f"--use_clang={str(use_clang).lower()}",
82+
"--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt"
83+
"--rocm_path=%/opt/rocm-{rocm_version}/",
84+
"--rocm_version=60",
8485
f"--rocm_amdgpu_targets={rocm_target}",
85-
f"--rocm_path=/opt/rocm-{rocm_version}/",
8686
bazel_options,
87+
"--verbose"
8788
]
8889

8990
if clang_option:

build/rocm/tools/build_wheels.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,12 @@ def build_jaxlib_wheel(
9393
cmd = [
9494
"python",
9595
"build/build.py",
96-
"--enable_rocm",
97-
"--build_gpu_plugin",
98-
"--gpu_plugin_rocm_version=60",
96+
"build"
97+
"--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt"
9998
"--rocm_path=%s" % rocm_path,
99+
"--rocm_version=60",
100100
"--use_clang=%s" % use_clang,
101+
"--verbose"
101102
]
102103

103104
# Add clang path if clang is used.

ci/build_artifacts.sh

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#!/bin/bash
2+
# Copyright 2024 The JAX Authors.
3+
##
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
# Build JAX artifacts.
17+
# Usage: ./ci/build_artifacts.sh "<artifact>"
18+
# Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt
19+
# E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib"
20+
#
21+
# -e: abort script if one command fails
22+
# -u: error if undefined variable used
23+
# -x: log all commands
24+
# -o history: record shell history
25+
# -o allexport: export all functions and variables to be available to subscripts
26+
set -exu -o history -o allexport
27+
28+
artifact="$1"
29+
30+
# Source default JAXCI environment variables.
31+
source ci/envs/default.env
32+
33+
# Set up the build environment.
34+
source "ci/utilities/setup_build_environment.sh"
35+
36+
allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt")
37+
38+
os=$(uname -s | awk '{print tolower($0)}')
39+
arch=$(uname -m)
40+
41+
# Adjust the values when running on Windows x86 to match the config in
42+
# .bazelrc
43+
if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then
44+
os="windows"
45+
arch="amd64"
46+
fi
47+
48+
if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
49+
50+
# Build the jax artifact
51+
if [[ "$artifact" == "jax" ]]; then
52+
python -m build --outdir $JAXCI_OUTPUT_DIR
53+
else
54+
55+
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
56+
# flags in the .bazelrc depending upon the platform we are building for.
57+
bazelrc_config="${os}_${arch}"
58+
59+
# TODO(b/379903748): Add remote cache options for Linux and Windows.
60+
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
61+
bazelrc_config="rbe_${bazelrc_config}"
62+
else
63+
bazelrc_config="ci_${bazelrc_config}"
64+
fi
65+
66+
# Use the "_cuda" configs when building the CUDA artifacts.
67+
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
68+
bazelrc_config="${bazelrc_config}_cuda"
69+
fi
70+
71+
# Build the artifact.
72+
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
73+
74+
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
75+
# run `auditwheel show` to verify manylinux compliance.
76+
if [[ "$os" == "linux" ]]; then
77+
./ci/utilities/run_auditwheel.sh
78+
fi
79+
80+
fi
81+
82+
else
83+
echo "Error: Invalid artifact: $artifact. Allowed values are: ${allowed_artifacts[@]}"
84+
exit 1
85+
fi

ci/envs/default.env

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,26 @@ export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-}
3434
export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0}
3535

3636
# Allows overriding the XLA commit that is used.
37-
export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-}
37+
export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-}
38+
39+
# Controls the location where the artifacts are written to.
40+
export JAXCI_OUTPUT_DIR="$(pwd)/dist"
41+
42+
# When enabled, artifacts will be built with RBE. Requires gcloud authentication
43+
# and only certain platforms support RBE. Therefore, this flag is enabled only
44+
# for CI builds where RBE is supported.
45+
export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
46+
47+
# #############################################################################
48+
# Test script specific environment variables.
49+
# #############################################################################
50+
# The maximum number of tests to run per GPU when running single accelerator
51+
# tests with parallel execution with Bazel. The GPU limit is set because we
52+
# need to allow about 2GB of GPU RAM per test. Default is set to 12 because we
53+
# use L4 machines which have 24GB of RAM but can be overriden if we use a
54+
# different GPU type.
55+
export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12}
56+
57+
# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override
58+
# this value in the Github action workflow files.
59+
export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0}

ci/run_bazel_test_cpu_rbe.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] )
5050
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
5151
--test_env=JAX_NUM_GENERATED_CASES=25 \
5252
--test_env=JAX_SKIP_SLOW_TESTS=true \
53-
--action_env=JAX_ENABLE_X64=0 \
53+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
5454
--test_output=errors \
5555
--color=yes \
5656
//tests:cpu_tests //tests:backend_independent_tests
@@ -61,7 +61,7 @@ else
6161
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
6262
--test_env=JAX_NUM_GENERATED_CASES=25 \
6363
--test_env=JAX_SKIP_SLOW_TESTS=true \
64-
--action_env=JAX_ENABLE_X64=0 \
64+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
6565
--test_output=errors \
6666
--color=yes \
6767
//tests:cpu_tests //tests:backend_independent_tests

ci/run_bazel_test_gpu_non_rbe.sh

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/bin/bash
2+
# Copyright 2024 The JAX Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
# Run Bazel GPU tests without RBE. This runs two commands: single accelerator
17+
# tests with one GPU a piece, multiaccelerator tests with all GPUS.
18+
# Requires that jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels are stored
19+
# inside the ../dist folder
20+
#
21+
# -e: abort script if one command fails
22+
# -u: error if undefined variable used
23+
# -x: log all commands
24+
# -o history: record shell history
25+
# -o allexport: export all functions and variables to be available to subscripts
26+
set -exu -o history -o allexport
27+
28+
# Source default JAXCI environment variables.
29+
source ci/envs/default.env
30+
31+
# Set up the build environment.
32+
source "ci/utilities/setup_build_environment.sh"
33+
34+
# Run Bazel GPU tests (single accelerator and multiaccelerator tests) directly
35+
# on the VM without RBE.
36+
nvidia-smi
37+
echo "Running single accelerator tests (without RBE)..."
38+
39+
# Set up test environment variables.
40+
export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
41+
export num_test_jobs=$((gpu_count * JAXCI_MAX_TESTS_PER_GPU))
42+
export num_cpu_cores=$(nproc)
43+
44+
# tests_jobs = max(gpu_count * max_tests_per_gpu, num_cpu_cores)
45+
if [[ $num_test_jobs -gt $num_cpu_cores ]]; then
46+
num_test_jobs=$num_cpu_cores
47+
fi
48+
# End of test environment variables setup.
49+
50+
# Runs single accelerator tests with one GPU apiece.
51+
# It appears --run_under needs an absolute path.
52+
# The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR`
53+
# should match the VM's CPU core count (set in `--local_test_jobs`).
54+
bazel test --config=ci_linux_x86_64_cuda \
55+
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
56+
--//jax:build_jaxlib=false \
57+
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
58+
--run_under "$(pwd)/build/parallel_accelerator_execute.sh" \
59+
--test_output=errors \
60+
--test_env=JAX_ACCELERATOR_COUNT=$gpu_count \
61+
--test_env=JAX_TESTS_PER_ACCELERATOR=$JAXCI_MAX_TESTS_PER_GPU \
62+
--local_test_jobs=$num_test_jobs \
63+
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
64+
--test_tag_filters=-multiaccelerator \
65+
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \
66+
--test_env=JAX_SKIP_SLOW_TESTS=true \
67+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
68+
--action_env=NCCL_DEBUG=WARN \
69+
--color=yes \
70+
//tests:gpu_tests //tests:backend_independent_tests \
71+
//tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
72+
73+
echo "Running multi-accelerator tests (without RBE)..."
74+
# Runs multiaccelerator tests with all GPUs directly on the VM without RBE..
75+
bazel test --config=ci_linux_x86_64_cuda \
76+
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
77+
--//jax:build_jaxlib=false \
78+
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
79+
--test_output=errors \
80+
--jobs=8 \
81+
--test_tag_filters=multiaccelerator \
82+
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \
83+
--test_env=JAX_SKIP_SLOW_TESTS=true \
84+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
85+
--action_env=NCCL_DEBUG=WARN \
86+
--color=yes \
87+
//tests:gpu_tests //tests/pallas:gpu_tests

ci/run_bazel_test_gpu_rbe.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ bazel test --config=rbe_linux_x86_64_cuda \
4646
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
4747
--test_tag_filters=-multiaccelerator \
4848
--test_env=JAX_SKIP_SLOW_TESTS=true \
49-
--action_env=JAX_ENABLE_X64=0 \
49+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
5050
--color=yes \
5151
//tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests

0 commit comments

Comments
 (0)