1313name : CI - Cloud TPU (nightly)
1414on :
1515 schedule :
16- - cron : " 0 14 * * *" # daily at 7am PST
16+ - cron : " * */2 * * *" # Run every 2 hours
1717 workflow_dispatch : # allows triggering the workflow run manually
1818# This should also be set to read-only in the project settings, but it's nice to
1919# document and enforce the permissions here.
@@ -26,15 +26,18 @@ jobs:
2626 matrix :
2727 jaxlib-version : ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
2828 tpu : [
29- {type: "v3-8", cores: "4"},
30- {type: "v4-8", cores: "4"},
31- {type: "v5e-8", cores: "8"}
29+ # {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
30+ # {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu "},
31+ {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu" }
3232 ]
33+ python-version : ["3.10"]
3334 name : " TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
3435 env :
3536 LIBTPU_OLDEST_VERSION_DATE : 20240722
3637 ENABLE_PJRT_COMPATIBILITY : ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
37- runs-on : ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
38+ PYTHON : python${{ matrix.python-version }}
39+ runs-on : ${{ matrix.tpu.runner }}
40+ container : " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
3841 timeout-minutes : 120
3942 defaults :
4043 run :
@@ -46,52 +49,52 @@ jobs:
4649 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
4750 - name : Install JAX test requirements
4851 run : |
49- pip install -U -r build/test-requirements.txt
50- pip install -U -r build/collect-profile-requirements.txt
52+ $PYTHON -m pip install -U -r build/test-requirements.txt
53+ $PYTHON -m pip install -U -r build/collect-profile-requirements.txt
5154 - name : Install JAX
5255 run : |
53- pip uninstall -y jax jaxlib libtpu
56+ $PYTHON -m pip uninstall -y jax jaxlib libtpu
5457 if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
55- pip install .[tpu] \
58+ $PYTHON -m pip install .[tpu] \
5659 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5760
5861 elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
59- pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
60- pip install --pre libtpu \
62+ $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
63+ $PYTHON -m pip install --pre libtpu \
6164 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
62- pip install requests
65+ $PYTHON -m pip install requests
6366
6467 elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
65- pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
68+ $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
6669 # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
67- pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
70+ $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6871 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
69- pip install requests
72+ $PYTHON -m pip install requests
7073 else
7174 echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7275 exit 1
7376 fi
7477
75- python3 -c 'import sys; print("python version:", sys.version)'
76- python3 -c 'import jax; print("jax version:", jax.__version__)'
77- python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
78- strings $HOME/. local/lib/python3.10/site -packages/libtpu/libtpu.so | grep 'Built on'
79- python3 -c 'import jax; print("libtpu version:",
78+ $PYTHON -c 'import sys; print("python version:", sys.version)'
79+ $PYTHON -c 'import jax; print("jax version:", jax.__version__)'
80+ $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
81+ strings /usr/ local/lib/"$PYTHON"/dist -packages/libtpu/libtpu.so | grep 'Built on'
82+ $PYTHON -c 'import jax; print("libtpu version:",
8083 jax.lib.xla_bridge.get_backend().platform_version)'
8184 - name : Run tests
8285 env :
8386 JAX_PLATFORMS : tpu,cpu
8487 PY_COLORS : 1
8588 run : |
8689 # Run single-accelerator tests in parallel
87- JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
90+ JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
8891 --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
8992 --maxfail=20 -m "not multiaccelerator" tests examples
9093 # Run Pallas printing tests, which need to run with I/O capturing disabled.
91- TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
94+ TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
9295 tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
9396 # Run multi-accelerator across all chips
94- python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
97+ $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
9598 - name : Send chat on failure
9699 # Don't notify when testing the workflow from a branch.
97100 if : ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
0 commit comments