Skip to content

Commit a6ab6bb

Browse files
nitins17Google-ML-Automation
authored andcommitted
Ignore Pallas TPU tests when testing with the oldest supported libtpu
I missed adding this in from https://github.com/jax-ml/jax/blob/main/.github/workflows/cloud-tpu-ci-nightly.yml when I added the TPU jobs to the new CI workflows PiperOrigin-RevId: 736094492
1 parent 61ba2b2 commit a6ab6bb

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

.github/workflows/pytest_tpu.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ jobs:
135135
echo "Using oldest supported libtpu"
136136
$JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
137137
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
138+
139+
echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV
138140
else
139141
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
140142
exit 1

ci/run_pytest_tpu.sh

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,19 @@ export JAX_SKIP_SLOW_TESTS=true
5353
echo "Running TPU tests..."
5454

5555
if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
56+
# We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic
57+
# TPU does not guarantee anything about forward compatibility (unless
58+
# jax.export is used) and the 12 week compatibility window accumulates way
59+
# too many failures.
60+
IGNORE_FLAGS=""
61+
if [ "${libtpu_version_type:-""}" == "oldest_supported_libtpu" ]; then
62+
IGNORE_FLAGS="--ignore=tests/pallas"
63+
fi
64+
5665
# Run single-accelerator tests in parallel
5766
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
5867
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
59-
--maxfail=20 -m "not multiaccelerator" tests examples
68+
--maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
6069

6170
# Run Pallas printing tests, which need to run with I/O capturing disabled.
6271
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s \

0 commit comments

Comments
 (0)