Skip to content

Commit 9cc5452

Browse files
Merge pull request #276 from ROCm/ci-upstream-sync-144_1
CI: 03/12/25 upstream sync
2 parents f14a1d0 + db8ba1b commit 9cc5452

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1793
-457
lines changed

.bazelrc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,6 @@ build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base
253253
build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc
254254
build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
255255

256-
# Mac x86 CI configs
257-
build:ci_darwin_x86_64 --macos_minimum_os=11.0
258-
build:ci_darwin_x86_64 --config=macos_cache_push
259-
build:ci_darwin_x86_64 --verbose_failures=true
260-
build:ci_darwin_x86_64 --color=yes
261-
262256
# Mac Arm64 CI configs
263257
build:ci_darwin_arm64 --macos_minimum_os=11.0
264258
build:ci_darwin_arm64 --config=macos_cache_push

.github/workflows/cloud-tpu-ci-presubmit.yml

Lines changed: 27 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
44
# tested to get to a stable state before we enable it as a blocking presubmit.
55
name: CI - Cloud TPU (presubmit)
6+
67
on:
78
workflow_dispatch:
89
inputs:
@@ -33,64 +34,32 @@ concurrency:
3334
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
3435

3536
jobs:
36-
cloud-tpu-test:
37+
build-jax-artifacts:
3738
if: github.event.repository.fork == false
38-
# Begin Presubmit Naming Check - name modification requires internal check to be updated
39+
uses: ./.github/workflows/build_artifacts.yml
3940
strategy:
40-
fail-fast: false # don't cancel all jobs on failure
41-
matrix:
42-
tpu: [
43-
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
44-
]
45-
python-version: ["3.10"]
46-
name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"
47-
# End Presubmit Naming Check github-tpu-presubmits
48-
env:
49-
JAXCI_PYTHON: python${{ matrix.python-version }}
50-
JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
51-
52-
runs-on: ${{ matrix.tpu.runner }}
53-
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
54-
55-
timeout-minutes: 60
41+
fail-fast: false # don't cancel all jobs on failure
42+
matrix:
43+
artifact: ["jax", "jaxlib"]
44+
with:
45+
runner: "linux-x86-n2-16"
46+
artifact: ${{ matrix.artifact }}
47+
python: "3.10"
48+
clone_main_xla: 1
49+
upload_artifacts_to_gcs: true
50+
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
5651

57-
defaults:
58-
run:
59-
shell: bash -ex {0}
60-
steps:
61-
# https://opensource.google/documentation/reference/github/services#actions
62-
# mandates using a specific commit for non-Google actions. We use
63-
# https://github.com/sethvargo/ratchet to pin specific versions.
64-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
65-
# Checkout XLA at head, if we're building jaxlib at head.
66-
- name: Checkout XLA at head
67-
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
68-
with:
69-
repository: openxla/xla
70-
path: xla
71-
# We need to mark the GitHub workspace as safe as otherwise git commands will fail.
72-
- name: Mark GitHub workspace as safe
73-
run: |
74-
git config --global --add safe.directory "$GITHUB_WORKSPACE"
75-
- name: Install JAX test requirements
76-
run: |
77-
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
78-
- name: Build jaxlib at head with latest XLA
79-
run: |
80-
# Build and install jaxlib at head
81-
$JAXCI_PYTHON build/build.py build --wheels=jaxlib \
82-
--python_version=${{ matrix.python-version }} \
83-
--bazel_options=--config=rbe_linux_x86_64 \
84-
--local_xla_path="$(pwd)/xla" \
85-
--verbose
86-
87-
# Install libtpu
88-
$JAXCI_PYTHON -m uv pip install --pre libtpu \
89-
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
90-
# Halt for testing
91-
- name: Wait For Connection
92-
uses: google-ml-infra/actions/ci_connection@main
93-
with:
94-
halt-dispatch-input: ${{ inputs.halt-for-connection }}
95-
- name: Install jaxlib wheel and run tests
96-
run: ./ci/run_pytest_tpu.sh
52+
run-pytest-tpu:
53+
if: github.event.repository.fork == false
54+
needs: [build-jax-artifacts]
55+
uses: ./.github/workflows/pytest_tpu.yml
56+
# Begin Presubmit Naming Check - name modification requires internal check to be updated
57+
name: "TPU test (jaxlib=head, v5e-8)"
58+
with:
59+
runner: "linux-x86-ct5lp-224-8tpu"
60+
cores: "8"
61+
tpu-type: "v5e-8"
62+
python: "3.10"
63+
libtpu-version-type: "nightly"
64+
gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }}
65+
# End Presubmit Naming Check github-tpu-presubmits

.github/workflows/pytest_cpu.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ jobs:
116116
exit 1
117117
- name: Install Python dependencies
118118
run: |
119+
# Remove installation of NVIDIA wheels for CPU tests.
120+
sed -i 's/-r gpu-test-requirements.txt/# -r gpu-test-requirements.txt/g' build/requirements.in
121+
119122
# TODO(srnitin): Remove after uv is installed in the Windows Dockerfile
120123
$JAXCI_PYTHON -m pip install uv~=0.5.30
121124
# python 3.13t cannot compile zstandard 0.23.0 due to

.github/workflows/pytest_tpu.yml

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# CI - Pytest TPU
2+
#
3+
# This workflow runs the TPU tests with Pytest. It can only be triggered by other workflows via
4+
# `workflow_call`. It is used by the "CI - Wheel Tests" and "CI - Cloud TPU (presubmit)" workflows to run the Pytest TPU tests.
5+
#
6+
# It consists of the following job:
7+
# run-tests:
8+
# - Downloads the jax and jaxlib wheels from a GCS bucket.
9+
# - Sets up the libtpu wheels.
10+
# - Executes the `run_pytest_tpu.sh` script, which performs the following actions:
11+
# - Installs the downloaded jaxlib wheel.
12+
# - Runs the TPU tests with Pytest.
13+
name: CI - Pytest TPU
14+
15+
on:
16+
workflow_call:
17+
inputs:
18+
# Note that the values for runners, cores, and tpu-type are linked to each other.
19+
# For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
20+
# following mapping:
21+
# {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
22+
# {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
23+
runner:
24+
description: "Which runner should the workflow run on?"
25+
type: string
26+
required: true
27+
default: "linux-x86-ct5lp-224-8tpu"
28+
cores:
29+
description: "How many TPU cores should the test use?"
30+
type: string
31+
required: true
32+
default: "8"
33+
tpu-type:
34+
description: "Which TPU type is used for testing?"
35+
type: string
36+
required: true
37+
default: "v5e-8"
38+
python:
39+
description: "Which Python version should be used for testing?"
40+
type: string
41+
required: true
42+
default: "3.12"
43+
run-full-tpu-test-suite:
44+
description: "Should the full TPU test suite be run?"
45+
type: string
46+
required: false
47+
default: "0"
48+
libtpu-version-type:
49+
description: "Which libtpu version should be used for testing?"
50+
type: string
51+
required: false
52+
# Choices are:
53+
# - "nightly": Use the nightly libtpu wheel.
54+
# - "pypi_latest": Use the latest libtpu wheel from PyPI.
55+
# - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
56+
default: "nightly"
57+
gcs_download_uri:
58+
description: "GCS location prefix from where the artifacts should be downloaded"
59+
required: true
60+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
61+
type: string
62+
halt-for-connection:
63+
description: 'Should this workflow run wait for a remote connection?'
64+
type: boolean
65+
required: false
66+
default: false
67+
68+
jobs:
69+
run-tests:
70+
defaults:
71+
run:
72+
shell: bash
73+
runs-on: ${{ inputs.runner }}
74+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
75+
# Begin Presubmit Naming Check - name modification requires internal check to be updated
76+
name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})"
77+
# End Presubmit Naming Check github-tpu-presubmits
78+
79+
env:
80+
LIBTPU_OLDEST_VERSION_DATE: 20241205
81+
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
82+
JAXCI_PYTHON: "python${{ inputs.python }}"
83+
JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
84+
JAXCI_TPU_CORES: "${{ inputs.cores }}"
85+
86+
steps:
87+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
88+
- name: Set env vars for use in artifact download URL
89+
run: |
90+
os=$(uname -s | awk '{print tolower($0)}')
91+
arch=$(uname -m)
92+
93+
# Get the major and minor version of Python.
94+
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
95+
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
96+
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')
97+
98+
echo "OS=${os}" >> $GITHUB_ENV
99+
echo "ARCH=${arch}" >> $GITHUB_ENV
100+
# Python wheels follow a naming convention: standard wheels use the pattern
101+
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
102+
# `*-cp<py_version>-cp<py_version>t-*`.
103+
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
104+
- name: Download JAX wheels from GCS
105+
id: download-wheel-artifacts
106+
# Set continue-on-error to true to prevent actions from failing the workflow if this step
107+
# fails. Instead, we verify the outcome in the step below so that we can print a more
108+
# informative error message.
109+
continue-on-error: true
110+
run: |
111+
mkdir -p $(pwd)/dist
112+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
113+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
114+
- name: Skip the test run if the wheel artifacts were not downloaded successfully
115+
if: steps.download-wheel-artifacts.outcome == 'failure'
116+
run: |
117+
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
118+
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
119+
echo "Skipping the test run."
120+
exit 1
121+
- name: Install Python dependencies
122+
run: |
123+
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
124+
- name: Set up libtpu wheels
125+
run: |
126+
if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then
127+
echo "Using nightly libtpu"
128+
$JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
129+
elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then
130+
echo "Using latest libtpu from PyPI"
131+
# Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh`
132+
# script will install the latest libtpu wheel from PyPI.
133+
echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV
134+
elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then
135+
echo "Using oldest supported libtpu"
136+
$JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
137+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
138+
139+
echo "libtpu_version_type=oldest_supported_libtpu" >> $GITHUB_ENV
140+
else
141+
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
142+
exit 1
143+
fi
144+
# Halt for testing
145+
- name: Wait For Connection
146+
uses: google-ml-infra/actions/ci_connection@main
147+
with:
148+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
149+
- name: Run Pytest TPU tests
150+
timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 180 }}
151+
run: ./ci/run_pytest_tpu.sh

.github/workflows/wheel_tests_continuous.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,30 @@ jobs:
142142
python: ${{ matrix.python }}
143143
enable-x64: ${{ matrix.enable-x64 }}
144144
# GCS upload URI is the same for both artifact build jobs
145+
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
146+
147+
run-pytest-tpu:
148+
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
149+
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
150+
# still want to run the tests for other platforms.
151+
if: ${{ !cancelled() }}
152+
needs: [build-jax-artifact, build-jaxlib-artifact]
153+
uses: ./.github/workflows/pytest_tpu.yml
154+
strategy:
155+
fail-fast: false # don't cancel all jobs on failure
156+
matrix:
157+
python: ["3.10",]
158+
tpu-specs: [
159+
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
160+
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
161+
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
162+
]
163+
name: "TPU tests (jax=head, jaxlib=head)"
164+
with:
165+
runner: ${{ matrix.tpu-specs.runner }}
166+
cores: ${{ matrix.tpu-specs.cores }}
167+
tpu-type: ${{ matrix.tpu-specs.type }}
168+
python: ${{ matrix.python }}
169+
run-full-tpu-test-suite: "1"
170+
libtpu-version-type: "nightly"
145171
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

.github/workflows/wheel_tests_nightly_release.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,42 @@ jobs:
5858
python: ${{ matrix.python }}
5959
cuda: ${{ matrix.cuda }}
6060
enable-x64: ${{ matrix.enable-x64 }}
61+
gcs_download_uri: ${{inputs.gcs_download_uri}}
62+
63+
run-pytest-tpu:
64+
uses: ./.github/workflows/pytest_tpu.yml
65+
strategy:
66+
fail-fast: false # don't cancel all jobs on failure
67+
matrix:
68+
# Skip Python 3.13 as it fails due to missing TensorFlow wheels (used for
69+
# profiler_test.py, build/collect-profile-requirements.txt) for that version (b/402590302)
70+
python: ["3.10", "3.11", "3.12"]
71+
tpu-specs: [
72+
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
73+
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
74+
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
75+
]
76+
libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"]
77+
exclude:
78+
- libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }}
79+
- libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }}
80+
# Run a single Python version for v4-8.
81+
- tpu-specs:
82+
type: "v4-8"
83+
python: "3.10"
84+
- tpu-specs:
85+
type: "v4-8"
86+
python: "3.11"
87+
# Run min and max Python versions for v5e-8
88+
- tpu-specs:
89+
type: "v5e-8"
90+
python: "3.11"
91+
name: "TPU tests (jax=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, jaxlib=${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
92+
with:
93+
runner: ${{ matrix.tpu-specs.runner }}
94+
cores: ${{ matrix.tpu-specs.cores }}
95+
tpu-type: ${{ matrix.tpu-specs.type }}
96+
python: ${{ matrix.python }}
97+
run-full-tpu-test-suite: "1"
98+
libtpu-version-type: ${{ matrix.libtpu-version-type }}
6199
gcs_download_uri: ${{inputs.gcs_download_uri}}

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,4 @@ For details about the JAX API, see the
456456

457457
For getting started as a JAX developer, see the
458458
[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).
459+

build/BUILD.bazel

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ compile_pip_requirements(
2929
requirements_in = "requirements.in",
3030
requirements_txt = REQUIREMENTS,
3131
generate_hashes = True,
32-
data = ["test-requirements.txt"]
32+
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
3333
)
3434

3535
compile_pip_requirements(
@@ -44,7 +44,7 @@ compile_pip_requirements(
4444
requirements_in = "requirements.in",
4545
requirements_txt = REQUIREMENTS,
4646
generate_hashes = False,
47-
data = ["test-requirements.txt"]
47+
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
4848
)
4949

5050
compile_pip_requirements(
@@ -58,7 +58,7 @@ compile_pip_requirements(
5858
requirements_in = "requirements.in",
5959
requirements_txt = REQUIREMENTS,
6060
generate_hashes = False,
61-
data = ["test-requirements.txt"]
61+
data = ["test-requirements.txt", "gpu-test-requirements.txt"]
6262
)
6363

6464
py_library(

build/gpu-test-requirements.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# NVIDIA CUDA dependencies
2+
# Note that the wheels are downloaded only when the targets in bazel command
3+
# contain dependencies on these wheels.
4+
nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux"
5+
nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux"
6+
nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux"
7+
nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux"
8+
nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux"
9+
nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux"
10+
nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux"
11+
nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux"
12+
nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux"
13+
nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux"

0 commit comments

Comments
 (0)