Skip to content

Commit 7d7a0fa

Browse files
nitins17Google-ML-Automation
authored andcommitted
Run the TPU workflow on new self-hosted runners
We are not able to run the TPU workflows because of no active runners (https://github.com/jax-ml/jax/actions/runs/11879479226/job/33101456081). So this adds the new self-hosted runners to the TPU workflow to fix this issue. The v3 type is disabled as we do not have that available yet. PiperOrigin-RevId: 698772505
1 parent 1bc9df4 commit 7d7a0fa

File tree

1 file changed

+26
-23
lines changed

1 file changed

+26
-23
lines changed

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
name: CI - Cloud TPU (nightly)
1414
on:
1515
schedule:
16-
- cron: "0 14 * * *" # daily at 7am PST
16+
- cron: "* */2 * * *" # Run every 2 hours
1717
workflow_dispatch: # allows triggering the workflow run manually
1818
# This should also be set to read-only in the project settings, but it's nice to
1919
# document and enforce the permissions here.
@@ -26,15 +26,18 @@ jobs:
2626
matrix:
2727
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
2828
tpu: [
29-
{type: "v3-8", cores: "4"},
30-
{type: "v4-8", cores: "4"},
31-
{type: "v5e-8", cores: "8"}
29+
# {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
30+
# {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
31+
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
3232
]
33+
python-version: ["3.10"]
3334
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
3435
env:
3536
LIBTPU_OLDEST_VERSION_DATE: 20240722
3637
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
37-
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
38+
PYTHON: python${{ matrix.python-version }}
39+
runs-on: ${{ matrix.tpu.runner }}
40+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
3841
timeout-minutes: 120
3942
defaults:
4043
run:
@@ -46,52 +49,52 @@ jobs:
4649
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
4750
- name: Install JAX test requirements
4851
run: |
49-
pip install -U -r build/test-requirements.txt
50-
pip install -U -r build/collect-profile-requirements.txt
52+
$PYTHON -m pip install -U -r build/test-requirements.txt
53+
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
5154
- name: Install JAX
5255
run: |
53-
pip uninstall -y jax jaxlib libtpu
56+
$PYTHON -m pip uninstall -y jax jaxlib libtpu
5457
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
55-
pip install .[tpu] \
58+
$PYTHON -m pip install .[tpu] \
5659
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5760
5861
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
59-
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
60-
pip install --pre libtpu \
62+
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
63+
$PYTHON -m pip install --pre libtpu \
6164
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
62-
pip install requests
65+
$PYTHON -m pip install requests
6366
6467
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
65-
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
68+
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
6669
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
67-
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
70+
$PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6871
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
69-
pip install requests
72+
$PYTHON -m pip install requests
7073
else
7174
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7275
exit 1
7376
fi
7477
75-
python3 -c 'import sys; print("python version:", sys.version)'
76-
python3 -c 'import jax; print("jax version:", jax.__version__)'
77-
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
78-
strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
79-
python3 -c 'import jax; print("libtpu version:",
78+
$PYTHON -c 'import sys; print("python version:", sys.version)'
79+
$PYTHON -c 'import jax; print("jax version:", jax.__version__)'
80+
$PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
81+
strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
82+
$PYTHON -c 'import jax; print("libtpu version:",
8083
jax.lib.xla_bridge.get_backend().platform_version)'
8184
- name: Run tests
8285
env:
8386
JAX_PLATFORMS: tpu,cpu
8487
PY_COLORS: 1
8588
run: |
8689
# Run single-accelerator tests in parallel
87-
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
90+
JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
8891
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
8992
--maxfail=20 -m "not multiaccelerator" tests examples
9093
# Run Pallas printing tests, which need to run with I/O capturing disabled.
91-
TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
94+
TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
9295
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
9396
# Run multi-accelerator across all chips
94-
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
97+
$PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
9598
- name: Send chat on failure
9699
# Don't notify when testing the workflow from a branch.
97100
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}

0 commit comments

Comments
 (0)