Skip to content

Commit f7e9f62

Browse files
nitins17Google-ML-Automation
authored andcommitted
Add new CI scripts for building JAX artifacts
This commit introduces new CI scripts and environment files for building JAX artifacts. It makes use of the artifact envs inside the "ci/envs/build_artifacts" folder to control the build behavior. For e.g: for building jaxlib, we will need to run `./ci/build_artifacts.sh ./ci/envs/build_artifacts/jaxlib.env` from the JAX GitHub root. PiperOrigin-RevId: 700104283
1 parent 788f493 commit f7e9f62

File tree

3 files changed

+140
-1
lines changed

3 files changed

+140
-1
lines changed

ci/build_artifacts.sh

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#!/bin/bash
2+
# Copyright 2024 The JAX Authors.
3+
##
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
# Build JAX artifacts.
17+
# Usage: ./ci/build_artifacts.sh "<artifact>"
18+
# Supported artifact values are: jax, jaxlib, jax-cuda-plugin, jax-cuda-pjrt
19+
# E.g: ./ci/build_artifacts.sh "jax" or ./ci/build_artifacts.sh "jaxlib"
20+
#
21+
# -e: abort script if one command fails
22+
# -u: error if undefined variable used
23+
# -x: log all commands
24+
# -o history: record shell history
25+
# -o allexport: export all functions and variables to be available to subscripts
26+
set -exu -o history -o allexport
27+
28+
artifact="$1"
29+
30+
# Source default JAXCI environment variables.
31+
source ci/envs/default.env
32+
33+
# Set up the build environment.
34+
source "ci/utilities/setup_build_environment.sh"
35+
36+
allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt")
37+
38+
os=$(uname -s | awk '{print tolower($0)}')
39+
arch=$(uname -m)
40+
41+
# Adjust the values when running on Windows x86 to match the config in
42+
# .bazelrc
43+
if [[ $os =~ "msys_nt" && $arch == "x86_64" ]]; then
44+
os="windows"
45+
arch="amd64"
46+
fi
47+
48+
if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
49+
50+
# Build the jax artifact
51+
if [[ "$artifact" == "jax" ]]; then
52+
python -m build --outdir $JAXCI_OUTPUT_DIR
53+
else
54+
55+
# Figure out the bazelrc config to use. We will use one of the "rbe_"/"ci_"
56+
# flags in the .bazelrc depending upon the platform we are building for.
57+
bazelrc_config="${os}_${arch}"
58+
59+
# TODO(b/379903748): Add remote cache options for Linux and Windows.
60+
if [[ "$JAXCI_BUILD_ARTIFACT_WITH_RBE" == 1 ]]; then
61+
bazelrc_config="rbe_${bazelrc_config}"
62+
else
63+
bazelrc_config="ci_${bazelrc_config}"
64+
fi
65+
66+
# Use the "_cuda" configs when building the CUDA artifacts.
67+
if [[ ("$artifact" == "jax-cuda-plugin") || ("$artifact" == "jax-cuda-pjrt") ]]; then
68+
bazelrc_config="${bazelrc_config}_cuda"
69+
fi
70+
71+
# Build the artifact.
72+
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
73+
74+
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
75+
# run `auditwheel show` to verify manylinux compliance.
76+
if [[ "$os" == "linux" ]]; then
77+
./ci/utilities/run_auditwheel.sh
78+
fi
79+
80+
fi
81+
82+
else
83+
echo "Error: Invalid artifact: $artifact. Allowed values are: ${allowed_artifacts[@]}"
84+
exit 1
85+
fi

ci/envs/default.env

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,12 @@ export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-}
3434
export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0}
3535

3636
# Allows overriding the XLA commit that is used.
37-
export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-}
37+
export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-}
38+
39+
# Controls the location where the artifacts are written to.
40+
export JAXCI_OUTPUT_DIR="$(pwd)/dist"
41+
42+
# When enabled, artifacts will be built with RBE. Requires gcloud authentication
43+
# and only certain platforms support RBE. Therefore, this flag is enabled only
44+
# for CI builds where RBE is supported.
45+
export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}

ci/utilities/run_auditwheel.sh

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/bin/bash
2+
# Copyright 2024 The JAX Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
#
17+
# Runs auditwheel to verify manylinux compatibility.
18+
19+
# Get a list of all the wheels in the output directory. Only look for wheels
20+
# that need to be verified for manylinux compliance.
21+
WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" \))
22+
23+
if [[ -z "$WHEELS" ]]; then
24+
echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR"
25+
exit 1
26+
fi
27+
28+
for wheel in $WHEELS; do
29+
printf "\nRunning auditwheel on the following wheel:"
30+
ls $wheel
31+
OUTPUT_FULL=$(python -m auditwheel show $wheel)
32+
# Remove the wheel name from the output to avoid false positives.
33+
wheel_name=$(basename $wheel)
34+
OUTPUT=${OUTPUT_FULL//${wheel_name}/}
35+
36+
# If a wheel is manylinux2014 compliant, `auditwheel show` will return the
37+
# platform tag as manylinux_2_17. manylinux2014 is an alias for
38+
# manylinux_2_17.
39+
if echo "$OUTPUT" | grep -q "manylinux_2_17"; then
40+
printf "\n$wheel_name is manylinux2014 compliant.\n"
41+
else
42+
echo "$OUTPUT_FULL"
43+
printf "\n$wheel_name is NOT manylinux2014 compliant.\n"
44+
exit 1
45+
fi
46+
done

0 commit comments

Comments
 (0)