Skip to content

Commit c7b407c

Browse files
committed
Merge branch 'rocm-main' into ci-upstream-sync-151_1
2 parents 1e36cbe + d864b4f commit c7b407c

23 files changed

+391
-140
lines changed

.bazelrc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
130130
build:clang --copt=-Qunused-arguments
131131
# Error on struct/class mismatches, since this causes link failures on Windows.
132132
build:clang --copt=-Werror=mismatched-tags
133+
# Don't error out on C++23 extensions. Needed for building the clang-19.
134+
build:clang --copt=-Wno-error=c23-extensions
133135

134136
# Configs for CUDA
135137
build:cuda --repo_env TF_NEED_CUDA=1

.github/CODEOWNERS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Require approvals from someone on the JAX team before PRs are merged
2+
* @ROCm/jax-devs
3+

.github/workflows/ci-build.yaml

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: CI
1+
name: ROCm CPU CI
22

33
# We test all supported Python versions as follows:
44
# - 3.10 : Documentation build
@@ -11,10 +11,10 @@ on:
1111
# but only for the main branch
1212
push:
1313
branches:
14-
- main
14+
- rocm-main
1515
pull_request:
1616
branches:
17-
- main
17+
- rocm-main
1818

1919
permissions:
2020
contents: read # to fetch code
@@ -42,12 +42,8 @@ jobs:
4242
- run: pre-commit run --show-diff-on-failure --color=always --all-files
4343

4444
build:
45-
# Don't execute in fork due to runner type
46-
if: github.repository == 'jax-ml/jax'
4745
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
48-
runs-on: linux-x86-n2-32
49-
container:
50-
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
46+
runs-on: ROCM-Ubuntu
5147
timeout-minutes: 60
5248
strategy:
5349
matrix:
@@ -65,10 +61,6 @@ jobs:
6561
num_generated_cases: 1
6662
steps:
6763
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
68-
- name: Image Setup
69-
run: |
70-
apt update
71-
apt install -y libssl-dev
7264
- name: Set up Python ${{ matrix.python-version }}
7365
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
7466
with:
@@ -95,12 +87,12 @@ jobs:
9587
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
9688
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
9789
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
98-
pytest -n auto --tb=short --maxfail=20 tests examples
90+
pytest -n 4 --tb=short --maxfail=20 tests examples
9991
10092
10193
documentation:
10294
name: Documentation - test code snippets
103-
runs-on: ubuntu-latest
95+
runs-on: ROCM-Ubuntu
10496
timeout-minutes: 10
10597
strategy:
10698
matrix:
@@ -128,19 +120,13 @@ jobs:
128120
129121
documentation_render:
130122
name: Documentation - render documentation
131-
runs-on: linux-x86-n2-16
132-
container:
133-
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
134-
timeout-minutes: 10
123+
runs-on: ubuntu-latest
124+
timeout-minutes: 20
135125
strategy:
136126
matrix:
137127
python-version: ['3.10']
138128
steps:
139129
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
140-
- name: Image Setup
141-
run: |
142-
apt update
143-
apt install -y libssl-dev libsqlite3-dev
144130
- name: Set up Python ${{ matrix.python-version }}
145131
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
146132
with:
@@ -194,9 +180,7 @@ jobs:
194180
195181
ffi:
196182
name: FFI example
197-
runs-on: linux-x86-g2-16-l4-1gpu
198-
container:
199-
image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
183+
runs-on: ROCM-Ubuntu
200184
timeout-minutes: 30
201185
steps:
202186
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -207,7 +191,8 @@ jobs:
207191
- name: Install JAX
208192
run: |
209193
pip install uv~=0.5.30
210-
uv pip install --system .[cuda12]
194+
pip install uv
195+
uv pip install --system .
211196
- name: Build and install example project
212197
run: uv pip install --system ./examples/ffi[test]
213198
env:
@@ -216,10 +201,11 @@ jobs:
216201
# a different toolchain. GCC is the default compiler on the
217202
# 'ubuntu-latest' runner, but we still set this explicitly just to be
218203
# clear.
219-
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
204+
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
220205
- name: Run CPU tests
221206
run: python -m pytest examples/ffi/tests
222207
env:
223208
JAX_PLATFORM_NAME: cpu
224209
- name: Run GPU tests
225210
run: python -m pytest examples/ffi/tests
211+

.github/workflows/rocm-ci.yml

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
1-
name: ROCm GPU Post-Merge Check
1+
name: ROCm GPU CI
22

33
on:
4-
# Trigger the workflow after a push into the main branch
4+
# Trigger the workflow on push or pull request,
5+
# but only for the rocm-main branch
56
push:
67
branches:
7-
- main
8-
9-
permissions:
10-
contents: read
8+
- rocm-main
9+
- 'rocm-jaxlib-v*'
10+
pull_request:
11+
branches:
12+
- rocm-main
13+
- 'rocm-jaxlib-v*'
1114

1215
concurrency:
1316
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
17+
cancel-in-progress: true
1418

1519
jobs:
16-
build-jax-in-docker:
17-
runs-on: linux-x86_64-cirrascale-64-8gpu-amd-mi250
20+
build-jax-in-docker: # strategy and matrix come here
21+
runs-on: mi-250
1822
env:
1923
BASE_IMAGE: "ubuntu:22.04"
20-
TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
24+
TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
2125
PYTHON_VERSION: "3.10"
2226
ROCM_VERSION: "6.2.4"
2327
WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
@@ -32,6 +36,9 @@ jobs:
3236
ls
3337
- name: Print system info
3438
run: |
39+
whoami
40+
printenv
41+
df -h
3542
rocm-smi
3643
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
3744
with:
@@ -50,9 +57,12 @@ jobs:
5057
uses: actions/upload-artifact@v4
5158
with:
5259
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
53-
path: ${{ env.WORKSPACE_DIR }}/dist/*.whl
54-
retention-days: 2
60+
path: ./dist/*.whl
5561
- name: Run tests
62+
env:
63+
ROCM_TEST_INCLUDE_SKIPS: "1"
64+
GPU_COUNT: "8"
65+
GFX: "gfx90a"
5666
run: |
5767
cd $WORKSPACE_DIR
5868
python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Pulls the latest changes from upstream into main and opens a PR to merge
2+
# them into rocm-main branch.
3+
4+
name: ROCm Nightly Upstream Sync
5+
on:
6+
workflow_dispatch:
7+
schedule:
8+
- cron: '0 6 * * 1-5'
9+
permissions:
10+
contents: write
11+
pull-requests: write
12+
env:
13+
SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }}
14+
jobs:
15+
sync-main:
16+
runs-on: ubuntu-latest
17+
steps:
18+
- name: Generate an app token
19+
id: generate-token
20+
uses: actions/create-github-app-token@v1
21+
with:
22+
app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
23+
private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
24+
- name: Sync our main with upstream main
25+
run: |
26+
gh auth status
27+
gh repo sync rocm/jax -b main
28+
env:
29+
GH_TOKEN: ${{ steps.generate-token.outputs.token }}
30+
create-sync-branch:
31+
needs: sync-main
32+
runs-on: ubuntu-latest
33+
env:
34+
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
35+
steps:
36+
- name: Checkout code
37+
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
38+
- name: Create branch
39+
run: |
40+
git fetch
41+
git checkout origin/main
42+
git checkout -b $SYNC_BRANCH_NAME
43+
# Try and merge rocm-main into this new branch so that we don't run upstream's CI code
44+
git config --global user.email "[email protected]"
45+
git config --global user.name "GitHub Actions"
46+
git merge origin/rocm-main || true
47+
# If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts
48+
git merge --abort || true
49+
git push origin HEAD
50+
open-sync-pr:
51+
needs: create-sync-branch
52+
runs-on: ubuntu-latest
53+
steps:
54+
- name: Generate an app token
55+
id: generate-token
56+
uses: actions/create-github-app-token@v1
57+
with:
58+
app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
59+
private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
60+
- name: Open a PR to rocm-main
61+
run: |
62+
gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream"
63+
gh pr merge --repo $GITHUB_REPOSITORY --merge --auto $SYNC_BRANCH_NAME
64+
env:
65+
GH_TOKEN: ${{ steps.generate-token.outputs.token }}
66+
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: ROCm Open Upstream PR
2+
on:
3+
pull_request:
4+
types: [ labeled ]
5+
branches: [ rocm-main ]
6+
jobs:
7+
open-upstream:
8+
if: ${{ github.event.label.name == 'open-upstream' }}
9+
permissions:
10+
contents: write
11+
pull-requests: write
12+
runs-on: ubuntu-latest
13+
env:
14+
NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream"
15+
steps:
16+
- name: Checkout code
17+
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
18+
- name: Rebase code to main
19+
run: |
20+
git config --global user.email "[email protected]"
21+
git config --global user.name "Github Actions"
22+
git fetch
23+
git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }}
24+
git rebase --onto origin/main origin/rocm-main
25+
# Force push here so that we don't run into conflicts with the origin branch
26+
git push origin HEAD --force
27+
- name: Leave link to create PR
28+
env:
29+
GH_TOKEN: ${{ github.token }}
30+
run: |
31+
# Bash is not friendly with newline characters, so make our own
32+
NL=$'\n'
33+
# Encode the PR title and body for passing as URL get parameters
34+
TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri')
35+
BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri')
36+
# Create a link to the that will open up a new PR form to upstream and autofill the fields
37+
CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC"
38+
# Add a comment with the link to the PR
39+
COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK"
40+
gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY"
41+

.github/workflows/upstream-nightly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ on:
2222

2323
jobs:
2424
upstream-dev:
25-
runs-on: ubuntu-latest
25+
runs-on: ROCM-Ubuntu
2626
permissions:
2727
contents: read
2828
issues: write # for failed-build-issue

build/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
262262
rocm_group.add_argument(
263263
"--rocm_amdgpu_targets",
264264
type=str,
265-
default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201",
265+
default="gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201",
266266
help="A comma-separated list of ROCm amdgpu targets to support.",
267267
)
268268

build/rocm/Dockerfile.ms

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y \
1313
&& rm -rf /var/lib/apt/lists/*
1414
1515
# Add target file to help determine which device(s) to build for
16-
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
16+
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
1717
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
1818
1919
# Install ROCm
@@ -62,7 +62,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
6262
pytest-reportlog \
6363
pytest-rerunfailures \
6464
pytest-json-report \
65-
pytest-csv \
6665
cloudpickle \
6766
portpicker \
6867
matplotlib \

build/rocm/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
207207
### Simplified Build Script
208208

209209
For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.
210+

0 commit comments

Comments
 (0)