# CI - Pytest TPU
#
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest TPU tests.
#
# It consists of the following job:
# run-tests:
#    - Downloads the jax and jaxlib wheels from a GCS bucket.
#    - Sets up the libtpu wheels.
#    - Executes the `run_pytest_tpu.sh` script, which performs the following actions:
#      - Installs the downloaded jaxlib wheel.
#      - Runs the TPU tests with Pytest.
name: CI - Pytest TPU

on:
  workflow_call:
    inputs:
      # Note that the values for runners, cores, and tpu-type are linked to each other.
      # For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
      # following mapping:
      # {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
      # {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: true
        default: "linux-x86-ct5lp-224-8tpu"
      cores:
        description: "How many TPU cores should the test use?"
        type: string
        required: true
        default: "8"
      tpu-type:
        description: "Which TPU type is used for testing?"
        type: string
        required: true
        default: "v5e-8"
      python:
        description: "Which Python version should be used for testing?"
        type: string
        required: true
        default: "3.12"
      # Forwarded verbatim to the job as JAXCI_RUN_FULL_TPU_TEST_SUITE.
      run-full-tpu-test-suite:
        description: "Should the full TPU test suite be run?"
        type: string
        required: false
        default: "0"
      libtpu-version-type:
        description: "Which libtpu version should be used for testing?"
        type: string
        required: false
        # Choices are:
        # - "nightly": Use the nightly libtpu wheel.
        # - "pypi_latest": Use the latest libtpu wheel from PyPI.
        # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
        # Any other value makes the "Set up libtpu wheels" step fail.
        default: "nightly"
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false

jobs:
  run-tests:
    defaults:
      run:
        shell: bash
    runs-on: ${{ inputs.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    # Begin Presubmit Naming Check - name modification requires internal check to be updated
    name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
    # End Presubmit Naming Check github-tpu-presubmits

    env:
      # Quoted so the all-digit date stays a string; it is appended to the
      # `libtpu-nightly==0.1.dev` version specifier below.
      LIBTPU_OLDEST_VERSION_DATE: "20241205"
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
      JAXCI_TPU_CORES: "${{ inputs.cores }}"

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)

          # Get the major and minor version of Python.
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
          python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          # Python wheels follow a naming convention: standard wheels use the pattern
          # `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
          # `*-cp<py_version>-cp<py_version>t-*`.
          echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> "$GITHUB_ENV"
      - name: Download JAX wheels from GCS
        id: download-wheel-artifacts
        # Set continue-on-error to true to prevent actions from failing the workflow if this step
        # fails. Instead, we verify the outcome in the step below so that we can print a more
        # informative error message.
        continue-on-error: true
        env:
          # Passed through the environment instead of being interpolated into the script so the
          # value is never parsed as shell code.
          GCS_DOWNLOAD_URI: "${{ inputs.gcs_download_uri }}"
        run: |
          mkdir -p "$(pwd)/dist"
          # The wildcard in the first command is left unquoted and the one in the second is
          # quoted, exactly as before; OS, ARCH and PYTHON_MAJOR_MINOR come from the previous step.
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}"/jax*py3*none*any.whl "$(pwd)/dist/"
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/"
      # Despite its name, this step fails the job (exit 1) when the download step failed.
      - name: Skip the test run if the wheel artifacts were not downloaded successfully
        if: steps.download-wheel-artifacts.outcome == 'failure'
        run: |
          echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
          echo "built successfully by the artifact build jobs and are available in the GCS bucket."
          echo "Skipping the test run."
          exit 1
      - name: Install Python dependencies
        run: |
          $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
      - name: Set up libtpu wheels
        env:
          # Passed through the environment instead of being interpolated into the script so the
          # value is never parsed as shell code.
          INPUT_LIBTPU_VERSION_TYPE: "${{ inputs.libtpu-version-type }}"
        run: |
          if [[ "${INPUT_LIBTPU_VERSION_TYPE}" == "nightly" ]]; then
            echo "Using nightly libtpu"
            $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [[ "${INPUT_LIBTPU_VERSION_TYPE}" == "pypi_latest" ]]; then
            echo "Using latest libtpu from PyPI"
            # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
            # script will install the latest libtpu wheel from PyPI.
            echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> "$GITHUB_ENV"
          elif [[ "${INPUT_LIBTPU_VERSION_TYPE}" == "oldest_supported_libtpu" ]]; then
            echo "Using oldest supported libtpu"
            # LIBTPU_OLDEST_VERSION_DATE is exported to the shell by the job-level `env` block.
            $JAXCI_PYTHON -m uv pip install --pre "libtpu-nightly==0.1.dev${LIBTPU_OLDEST_VERSION_DATE}" \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

            echo "libtpu_version_type=oldest_supported_libtpu" >> "$GITHUB_ENV"
          else
            echo "Unknown libtpu version type: ${INPUT_LIBTPU_VERSION_TYPE}"
            exit 1
          fi
      # Halt for testing
      # NOTE(review): this action is referenced by the mutable `main` branch while
      # actions/checkout above is pinned to a commit SHA - consider pinning this one too.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Pytest TPU tests
        # 30 minutes when the triggering event is a pull request, 180 minutes otherwise.
        timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
        run: ./ci/run_pytest_tpu.sh