2929 type : string
3030 required : true
3131 default : " 0"
32+ install-jax-current-commit :
33+ description : " Should the 'jax' package be installed from the current commit?"
34+ type : string
35+ required : true
36+ default : " 1"
3237 gcs_download_uri :
3338 description : " GCS location prefix from where the artifacts should be downloaded"
3439 required : true
5762 JAXCI_HERMETIC_PYTHON_VERSION : " ${{ inputs.python }}"
5863 JAXCI_PYTHON : " python${{ inputs.python }}"
5964 JAXCI_ENABLE_X64 : " ${{ inputs.enable-x64 }}"
65+ JAXCI_INSTALL_JAX_CURRENT_COMMIT : " ${{ inputs.install-jax-current-commit }}"
6066
6167 steps :
6268 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -79,18 +85,47 @@ jobs:
7985 echo "ARCH=${arch}" >> $GITHUB_ENV
8086 echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
8187 - name : Download jaxlib wheel from GCS (non-Windows runs)
88+ id : download-wheel-artifacts-nw
89+ # Set continue-on-error to true to prevent actions from failing the workflow if this step
90+ # fails. Instead, we verify the outcome in the step below so that we can print a more
91+ # informative error message.
92+ continue-on-error : true
8293 if : ${{ !contains(inputs.runner, 'windows-x86') }}
83- run : >-
94+ run : |
8495 mkdir -p $(pwd)/dist &&
8596 gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
97+
98+ # Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
99+ if [[ "${{ inputs.install-jax-current-commit }}" != 1 ]]; then
100+ gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
101+ fi
86102 - name : Download jaxlib wheel from GCS (Windows runs)
103+ id : download-wheel-artifacts-w
104+ # Set continue-on-error to true to prevent actions from failing the workflow if this step
105+ # fails. Instead, we verify the outcome in step below so that we can print a more
106+ # informative error message.
107+ continue-on-error : true
87108 if : ${{ contains(inputs.runner, 'windows-x86') }}
88109 shell : cmd
89- run : >-
90- mkdir dist &&
91- gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
110+ run : |
111+ mkdir dist
112+ @REM Use `call` so that we can run sequential gsutil commands on Windows
113+ @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
114+ call gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
115+
116+ @REM Download the "jax" wheel from GCS if inputs.install-jax-current-commit is not set to 1
117+ if not "${{ inputs.install-jax-current-commit }}"=="1" (
118+ call gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/
119+ )
120+ - name : Skip the test run if the wheel artifacts were not downloaded successfully
121+ if : steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
122+ run : |
123+ echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
124+ echo "built successfully by the artifact build jobs and are available in the GCS bucket."
125+ echo "Skipping the test run."
126+ exit 1
92127 - name : Install Python dependencies
93- run : $JAXCI_PYTHON -m pip install -r build/requirements.in
128+ run : $JAXCI_PYTHON -m uv pip install -r build/requirements.in
94129 # Halt for testing
95130 - name : Wait For Connection
96131 uses : google-ml-infra/actions/ci_connection@main
0 commit comments