Skip to content

Commit 34a2f0c

Browse files
nitins17Google-ML-Automation
authored andcommitted
Add a jaxlib at head build to the cloud-tpu-ci-nightly workflow
This will allow us to test TPU compatibility with jaxlib at head. Also, enable v4 runners as they are now online. PiperOrigin-RevId: 699155667
1 parent 73fa0f4 commit 34a2f0c

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ jobs:
2424
strategy:
2525
fail-fast: false # don't cancel all jobs on failure
2626
matrix:
27-
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
27+
jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
2828
tpu: [
29-
# {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
30-
# {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
29+
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
30+
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
3131
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
3232
]
3333
python-version: ["3.10"]
@@ -47,14 +47,34 @@ jobs:
4747
# mandates using a specific commit for non-Google actions. We use
4848
# https://github.com/sethvargo/ratchet to pin specific versions.
4949
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
50+
# Checkout XLA at head, if we're building jaxlib at head.
51+
- name: Checkout XLA at head
52+
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
53+
if: ${{ matrix.jaxlib-version == 'head' }}
54+
with:
55+
repository: openxla/xla
56+
path: xla
5057
- name: Install JAX test requirements
5158
run: |
5259
$PYTHON -m pip install -U -r build/test-requirements.txt
5360
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
5461
- name: Install JAX
5562
run: |
5663
$PYTHON -m pip uninstall -y jax jaxlib libtpu
57-
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
64+
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
65+
# Build and install jaxlib at head
66+
$PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \
67+
--bazel_options="--override_repository=xla=$(pwd)/xla" \
68+
--bazel_options=--color=yes
69+
$PYTHON -m pip install dist/*.whl
70+
71+
# Install "jax" at head
72+
$PYTHON -m pip install -U -e .
73+
74+
# Install libtpu
75+
$PYTHON -m pip install --pre libtpu \
76+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
77+
elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
5878
$PYTHON -m pip install .[tpu] \
5979
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6080

0 commit comments

Comments
 (0)