Skip to content

Commit a366d41

Browse files
Merge pull request #213 from ROCm/ci-upstream-sync-97_1
CI: 01/27/25 upstream sync
2 parents 653f773 + 41ab12b commit a366d41

File tree

292 files changed

+14398
-4195
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

292 files changed

+14398
-4195
lines changed

.bazelrc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ build:cuda --repo_env TF_NCCL_USE_STUB=1
120120
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
121121
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
122122
build:cuda --@local_config_cuda//:enable_cuda
123-
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
124123

125124
# Default hermetic CUDA and CUDNN versions.
126125
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
@@ -158,7 +157,7 @@ build:win_clang --compiler=clang-cl
158157
build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain
159158
build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true
160159
build:rocm_base --repo_env TF_NEED_ROCM=1
161-
build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100"
160+
build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201"
162161

163162
# Build with hipcc for ROCm and clang for the host.
164163
build:rocm --config=rocm_base
@@ -171,6 +170,12 @@ build:rocm --action_env=TF_HIPCC_CLANG="1"
171170
# #############################################################################
172171
# Cache options below.
173172
# #############################################################################
173+
# Public read-only cache
174+
build:public_cache --remote_cache="https://storage.googleapis.com/jax-bazel-cache/" --remote_upload_local_results=false
175+
# Cache pushes are limited to JAX's CI system.
176+
build:public_cache_push --config=public_cache --remote_upload_local_results=true --google_default_credentials
177+
178+
# Note: the following cache configs are deprecated and will be removed soon.
174179
# Public read-only cache for Mac builds. JAX uses a GCS bucket to store cache
175180
# from JAX's Mac CI build. By applying --config=macos_cache, any local Mac build
176181
# should be able to read from this cache and potentially see a speedup. The

.github/workflows/bazel_cpu_rbe.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
3939

4040
steps:
41-
- uses: actions/checkout@v3
41+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
4242
- name: Wait For Connection
4343
uses: google-ml-infra/actions/ci_connection@main
4444
with:
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# CI - Bazel CUDA tests (Non-RBE)
2+
#
3+
# This workflow runs the CUDA tests with Bazel. It can only be triggered by other workflows via
4+
# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Bazel CUDA tests.
5+
#
6+
# It consists of the following job:
7+
# run-tests:
8+
# - Downloads the jaxlib and CUDA artifacts from a GCS bucket.
9+
# - Executes the `run_bazel_test_cuda_non_rbe.sh` script, which performs the following actions:
10+
# - Installs the downloaded wheel artifacts.
11+
# - Runs the CUDA tests with Bazel.
12+
name: CI - Bazel CUDA tests (Non-RBE)
13+
14+
on:
15+
workflow_call:
16+
inputs:
17+
runner:
18+
description: "Which runner should the workflow run on?"
19+
type: string
20+
required: true
21+
default: "linux-x86-n2-16"
22+
python:
23+
description: "Which python version to test?"
24+
type: string
25+
required: true
26+
default: "3.12"
27+
enable-x64:
28+
description: "Should x64 mode be enabled?"
29+
type: string
30+
required: true
31+
default: "0"
32+
gcs_download_uri:
33+
description: "GCS location URI from where the artifacts should be downloaded"
34+
required: true
35+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
36+
type: string
37+
halt-for-connection:
38+
description: 'Should this workflow run wait for a remote connection?'
39+
type: boolean
40+
required: false
41+
default: false
42+
43+
jobs:
44+
run-tests:
45+
runs-on: ${{ inputs.runner }}
46+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
47+
48+
env:
49+
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
50+
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
51+
# Enable writing to the Bazel remote cache bucket.
52+
JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1"
53+
54+
name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
55+
56+
steps:
57+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
58+
- name: Set env vars for use in artifact download URL
59+
run: |
60+
os=$(uname -s | awk '{print tolower($0)}')
61+
arch=$(uname -m)
62+
63+
# Get the major and minor version of Python.
64+
# E.g., if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
65+
python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')
66+
67+
echo "OS=${os}" >> $GITHUB_ENV
68+
echo "ARCH=${arch}" >> $GITHUB_ENV
69+
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
70+
- name: Download the wheel artifacts from GCS
71+
run: >-
72+
mkdir -p $(pwd)/dist &&
73+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
74+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
75+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
76+
# Halt for testing
77+
- name: Wait For Connection
78+
uses: google-ml-infra/actions/ci_connection@main
79+
with:
80+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
81+
- name: Run Bazel CUDA tests (Non-RBE)
82+
timeout-minutes: 60
83+
run: ./ci/run_bazel_test_cuda_non_rbe.sh

.github/workflows/bazel_gpu_rbe.yml renamed to .github/workflows/bazel_cuda_rbe.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: CI - Bazel GPU tests (RBE)
1+
name: CI - Bazel CUDA tests (RBE)
22

33
on:
44
workflow_dispatch:
@@ -34,13 +34,13 @@ jobs:
3434
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
3535
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
3636

37-
name: "Bazel single accelerator GPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
37+
name: "Bazel single accelerator CUDA tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
3838

3939
steps:
40-
- uses: actions/checkout@v3
40+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
4141
- name: Wait For Connection
4242
uses: google-ml-infra/actions/ci_connection@main
4343
with:
4444
halt-dispatch-input: ${{ inputs.halt-for-connection }}
45-
- name: Run Bazel GPU Tests with RBE
46-
run: ./ci/run_bazel_test_gpu_rbe.sh
45+
- name: Run Bazel CUDA Tests with RBE
46+
run: ./ci/run_bazel_test_cuda_rbe.sh
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# CI - Build JAX Artifacts
2+
# This workflow builds JAX wheels (jax, jaxlib, jax-cuda-plugin, and jax-cuda-pjrt) with a set of
3+
# configuration options (platform, python version, whether to use latest XLA, etc). It can be
4+
# triggered manually via workflow_dispatch or called by other workflows via workflow_call. When a
5+
# workflow call is made, this workflow will build the artifacts and upload them to a GCS bucket so
6+
# that other workflows (e.g. Pytest workflows) can use them.
7+
name: CI - Build JAX Artifacts
8+
9+
on:
10+
workflow_dispatch:
11+
inputs:
12+
runner:
13+
description: "Which runner should the workflow run on?"
14+
type: choice
15+
required: true
16+
default: "linux-x86-n2-16"
17+
options:
18+
- "linux-x86-n2-16"
19+
- "linux-arm64-c4a-64"
20+
- "windows-x86-n2-16"
21+
artifact:
22+
description: "Which JAX artifact to build?"
23+
type: choice
24+
required: true
25+
default: "jaxlib"
26+
options:
27+
- "jax"
28+
- "jaxlib"
29+
- "jax-cuda-plugin"
30+
- "jax-cuda-pjrt"
31+
python:
32+
description: "Which python version should the artifact be built for?"
33+
type: choice
34+
required: false
35+
default: "3.12"
36+
options:
37+
- "3.10"
38+
- "3.11"
39+
- "3.12"
40+
- "3.13"
41+
clone_main_xla:
42+
description: "Should latest XLA be used?"
43+
type: choice
44+
required: false
45+
default: "0"
46+
options:
47+
- "1"
48+
- "0"
49+
halt-for-connection:
50+
description: 'Should this workflow run wait for a remote connection?'
51+
type: choice
52+
required: false
53+
default: 'no'
54+
options:
55+
- 'yes'
56+
- 'no'
57+
workflow_call:
58+
inputs:
59+
runner:
60+
description: "Which runner should the workflow run on?"
61+
type: string
62+
required: true
63+
default: "linux-x86-n2-16"
64+
artifact:
65+
description: "Which JAX artifact to build?"
66+
type: string
67+
required: true
68+
default: "jaxlib"
69+
python:
70+
description: "Which python version should the artifact be built for?"
71+
type: string
72+
required: false
73+
default: "3.12"
74+
clone_main_xla:
75+
description: "Should latest XLA be used?"
76+
type: string
77+
required: false
78+
default: "0"
79+
upload_artifacts_to_gcs:
80+
description: "Should the artifacts be uploaded to a GCS bucket?"
81+
required: true
82+
default: true
83+
type: boolean
84+
gcs_upload_uri:
85+
description: "GCS location prefix to where the artifacts should be uploaded"
86+
required: true
87+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
88+
type: string
89+
outputs:
90+
gcs_upload_uri:
91+
description: "GCS location prefix to where the artifacts were uploaded"
92+
value: ${{ jobs.build-artifacts.outputs.gcs_upload_uri }}
93+
94+
permissions:
95+
contents: read
96+
97+
jobs:
98+
build-artifacts:
99+
defaults:
100+
run:
101+
# Explicitly set the shell to bash to override the Windows default (PowerShell)
102+
shell: bash
103+
104+
runs-on: ${{ inputs.runner }}
105+
106+
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
107+
(contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
108+
(contains(inputs.runner, 'windows-x86') && null) }}
109+
110+
env:
111+
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
112+
JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
113+
114+
name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }})
115+
116+
# Map the job outputs to step outputs
117+
outputs:
118+
gcs_upload_uri: ${{ steps.store-gcs-upload-uri.outputs.gcs_upload_uri }}
119+
120+
steps:
121+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
122+
- name: Enable RBE if building on Linux x86 or Windows x86
123+
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
124+
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
125+
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64
126+
if: contains(inputs.runner, 'linux-arm64')
127+
run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV
128+
# Halt for testing
129+
- name: Wait For Connection
130+
uses: google-ml-infra/actions/ci_connection@main
131+
with:
132+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
133+
- name: Build ${{ inputs.artifact }}
134+
timeout-minutes: 30
135+
run: ./ci/build_artifacts.sh "${{ inputs.artifact }}"
136+
- name: Upload artifacts to a GCS bucket (non-Windows runs)
137+
if: >-
138+
${{ inputs.upload_artifacts_to_gcs && !contains(inputs.runner, 'windows-x86') }}
139+
run: gsutil -m cp -r "$(pwd)/dist/*.whl" "${{ inputs.gcs_upload_uri }}"/
140+
# Set shell to cmd to avoid path errors when using gcloud commands on Windows
141+
- name: Upload artifacts to a GCS bucket (Windows runs)
142+
if: >-
143+
${{ inputs.upload_artifacts_to_gcs && contains(inputs.runner, 'windows-x86') }}
144+
shell: cmd
145+
run: gsutil -m cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/
146+
- name: Store the GCS upload URI as an output
147+
id: store-gcs-upload-uri
148+
if: ${{ inputs.upload_artifacts_to_gcs }}
149+
run: echo "gcs_upload_uri=${{ inputs.gcs_upload_uri }}" >> "$GITHUB_OUTPUT"

.github/workflows/parse_logs.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

0 commit comments

Comments
 (0)