File tree Expand file tree Collapse file tree 2 files changed +12
-1
lines changed Expand file tree Collapse file tree 2 files changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -135,6 +135,8 @@ jobs:
135135 echo "Using oldest supported libtpu"
136136 $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
137137 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
138+
139+ echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV
138140 else
139141 echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
140142 exit 1
Original file line number Diff line number Diff line change @@ -53,10 +53,19 @@ export JAX_SKIP_SLOW_TESTS=true
5353echo " Running TPU tests..."
5454
5555if [[ " $JAXCI_RUN_FULL_TPU_TEST_SUITE " == " 1" ]]; then
56+ # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic
57+ # TPU does not guarantee anything about forward compatibility (unless
58+ # jax.export is used) and the 12 week compatibility window accumulates way
59+ # too many failures.
60+ IGNORE_FLAGS=" "
61+ if [ " ${libtpu_version_type:- " " } " == " oldest_supported_libtpu" ]; then
62+ IGNORE_FLAGS=" --ignore=tests/pallas"
63+ fi
64+
5665 # Run single-accelerator tests in parallel
5766 JAX_ENABLE_TPU_XDIST=true " $JAXCI_PYTHON " -m pytest -n=" $JAXCI_TPU_CORES " --tb=short \
5867 --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
59- --maxfail=20 -m " not multiaccelerator" tests examples
68+ --maxfail=20 -m " not multiaccelerator" $IGNORE_FLAGS tests examples
6069
6170 # Run Pallas printing tests, which need to run with I/O capturing disabled.
6271 TPU_STDERR_LOG_LEVEL=0 " $JAXCI_PYTHON " -m pytest -s \
You can’t perform that action at this time.
0 commit comments