Skip to content

Commit 31c1f25

Browse files
committed
Merge branch 'rocm-main' into ci-upstream-sync-110_1
2 parents cf308a8 + 9fbc1c1 commit 31c1f25

File tree

12 files changed

+245
-41
lines changed

12 files changed

+245
-41
lines changed

.github/workflows/ci-build.yaml

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: CI
1+
name: ROCm CPU CI
22

33
# We test all supported Python versions as follows:
44
# - 3.10 : Documentation build
@@ -11,10 +11,10 @@ on:
1111
# but only for the main branch
1212
push:
1313
branches:
14-
- main
14+
- rocm-main
1515
pull_request:
1616
branches:
17-
- main
17+
- rocm-main
1818

1919
permissions:
2020
contents: read # to fetch code
@@ -42,12 +42,8 @@ jobs:
4242
- run: pre-commit run --show-diff-on-failure --color=always --all-files
4343

4444
build:
45-
# Don't execute in fork due to runner type
46-
if: github.repository == 'jax-ml/jax'
4745
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
48-
runs-on: linux-x86-n2-32
49-
container:
50-
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
46+
runs-on: ROCM-Ubuntu
5147
timeout-minutes: 60
5248
strategy:
5349
matrix:
@@ -65,10 +61,6 @@ jobs:
6561
num_generated_cases: 1
6662
steps:
6763
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
68-
- name: Image Setup
69-
run: |
70-
apt update
71-
apt install -y libssl-dev
7264
- name: Set up Python ${{ matrix.python-version }}
7365
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
7466
with:
@@ -95,12 +87,12 @@ jobs:
9587
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
9688
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
9789
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
98-
pytest -n auto --tb=short --maxfail=20 tests examples
90+
pytest -n 4 --tb=short --maxfail=20 tests examples
9991
10092
10193
documentation:
10294
name: Documentation - test code snippets
103-
runs-on: ubuntu-latest
95+
runs-on: ROCM-Ubuntu
10496
timeout-minutes: 10
10597
strategy:
10698
matrix:
@@ -128,19 +120,13 @@ jobs:
128120
129121
documentation_render:
130122
name: Documentation - render documentation
131-
runs-on: linux-x86-n2-16
132-
container:
133-
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
134-
timeout-minutes: 10
123+
runs-on: ubuntu-latest
124+
timeout-minutes: 20
135125
strategy:
136126
matrix:
137127
python-version: ['3.10']
138128
steps:
139129
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
140-
- name: Image Setup
141-
run: |
142-
apt update
143-
apt install -y libssl-dev libsqlite3-dev
144130
- name: Set up Python ${{ matrix.python-version }}
145131
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
146132
with:
@@ -193,9 +179,7 @@ jobs:
193179
194180
ffi:
195181
name: FFI example
196-
runs-on: linux-x86-g2-16-l4-1gpu
197-
container:
198-
image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
182+
runs-on: ROCM-Ubuntu
199183
timeout-minutes: 30
200184
steps:
201185
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -206,7 +190,7 @@ jobs:
206190
- name: Install JAX
207191
run: |
208192
pip install uv
209-
uv pip install --system .[cuda12]
193+
uv pip install --system .
210194
- name: Build and install example project
211195
run: uv pip install --system ./examples/ffi[test]
212196
env:
@@ -215,10 +199,11 @@ jobs:
215199
# a different toolchain. GCC is the default compiler on the
216200
# 'ubuntu-latest' runner, but we still set this explicitly just to be
217201
# clear.
218-
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
202+
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
219203
- name: Run CPU tests
220204
run: python -m pytest examples/ffi/tests
221205
env:
222206
JAX_PLATFORM_NAME: cpu
223207
- name: Run GPU tests
224208
run: python -m pytest examples/ffi/tests
209+

.github/workflows/rocm-ci.yml

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
1-
name: ROCm GPU Post-Merge Check
1+
name: ROCm GPU CI
22

33
on:
4-
# Trigger the workflow after a push into the main branch
4+
# Trigger the workflow on push or pull request,
5+
# but only for the rocm-main branch
56
push:
67
branches:
7-
- main
8-
9-
permissions:
10-
contents: read
8+
- rocm-main
9+
pull_request:
10+
branches:
11+
- rocm-main
1112

1213
concurrency:
1314
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
15+
cancel-in-progress: true
1416

1517
jobs:
16-
build-jax-in-docker:
17-
runs-on: linux-x86_64-cirrascale-64-8gpu-amd-mi250
18+
build-jax-in-docker: # strategy and matrix come here
19+
runs-on: mi-250
1820
env:
1921
BASE_IMAGE: "ubuntu:22.04"
20-
TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
22+
TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
2123
PYTHON_VERSION: "3.10"
2224
ROCM_VERSION: "6.2.4"
2325
WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
@@ -32,6 +34,9 @@ jobs:
3234
ls
3335
- name: Print system info
3436
run: |
37+
whoami
38+
printenv
39+
df -h
3540
rocm-smi
3641
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
3742
with:
@@ -50,10 +55,9 @@ jobs:
5055
uses: actions/upload-artifact@v4
5156
with:
5257
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
53-
path: ${{ env.WORKSPACE_DIR }}/dist/*.whl
54-
retention-days: 2
58+
path: ./dist/*.whl
5559
- name: Run tests
5660
run: |
5761
cd $WORKSPACE_DIR
58-
python3 build/rocm/ci_build test $TEST_IMAGE
62+
python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"
5963
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Pulls the latest changes from upstream into main and opens a PR to merge
# them into the rocm-main branch.
#
# Jobs run strictly in sequence:
#   sync-main          - fast-forwards this fork's main from its upstream parent
#   create-sync-branch - branches off main, tries to merge rocm-main, pushes
#   open-sync-pr       - opens a PR into rocm-main and enables auto-merge

name: ROCm Nightly Upstream Sync
on:
  workflow_dispatch:
  schedule:
    # 06:00 on weekdays (Mon-Fri); GitHub evaluates cron schedules in UTC.
    - cron: '0 6 * * 1-5'
permissions:
  contents: write
  pull-requests: write
env:
  # Unique per run and per re-run so retries never collide with an old branch.
  SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }}
jobs:
  sync-main:
    runs-on: ubuntu-latest
    steps:
      - run: |
          gh auth status
          gh repo sync rocm/jax -b main
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  create-sync-branch:
    needs: sync-main
    runs-on: ubuntu-latest
    env:
      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
    steps:
      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          # Full history: with the default depth-1 checkout the checked-out
          # branch tip has no ancestry, so the merge below cannot find a
          # common ancestor with origin/main.
          fetch-depth: 0
      - name: Create branch
        run: |
          git fetch
          git checkout origin/main
          git checkout -b "$SYNC_BRANCH_NAME"
          # Try and merge rocm-main into this new branch so that we don't run upstream's CI code
          # NOTE(review): the address below was obfuscated in the reviewed copy
          # ("[email protected]"); confirm the intended committer e-mail.
          git config --global user.email "github-actions@github.com"
          git config --global user.name "GitHub Actions"
          # If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts
          if ! git merge origin/rocm-main; then
            git merge --abort || true
          fi
          git push origin HEAD
  open-sync-pr:
    needs: create-sync-branch
    runs-on: ubuntu-latest
    steps:
      - run: |
          gh pr create --repo "$GITHUB_REPOSITORY" --head "$SYNC_BRANCH_NAME" --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream"
          gh pr merge --repo "$GITHUB_REPOSITORY" --merge --auto "$SYNC_BRANCH_NAME"
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
52+
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# When a PR targeting rocm-main is labeled 'open-upstream', rebase its commits
# onto main, push them as "<head branch>-upstream", and comment on the PR with
# a pre-filled link for opening the equivalent PR against jax-ml/jax.
#
# Every value that comes from the PR (branch name, title, body, number) reaches
# the shell through `env:` rather than `${{ }}` expansion inside `run:`. Inline
# expansion pastes attacker-controlled text into the script source, which lets a
# crafted title, body, or branch name execute arbitrary commands.
name: ROCm Open Upstream PR
on:
  pull_request:
    types: [ labeled ]
    branches: [ rocm-main ]
jobs:
  open-upstream:
    if: ${{ github.event.label.name == 'open-upstream' }}
    permissions:
      contents: write
      pull-requests: write
    runs-on: ubuntu-latest
    env:
      HEAD_BRANCH_NAME: ${{ github.head_ref }}
      NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream"
    steps:
      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Rebase code to main
        run: |
          # NOTE(review): the address below was obfuscated in the reviewed copy
          # ("[email protected]"); confirm the intended committer e-mail.
          git config --global user.email "github-actions@github.com"
          git config --global user.name "Github Actions"
          git fetch
          git checkout -b "$NEW_BRANCH_NAME" "origin/$HEAD_BRANCH_NAME"
          # Replay only the commits that are on the PR branch but not on rocm-main.
          git rebase --onto origin/main origin/rocm-main
          # Force push here so that we don't run into conflicts with the origin branch
          git push origin HEAD --force
      - name: Leave link to create PR
        env:
          GH_TOKEN: ${{ github.token }}
          PR_TITLE: ${{ github.event.pull_request.title }}
          PR_BODY: ${{ github.event.pull_request.body }}
          PR_NUMBER: ${{ github.event.pull_request.number }}
        run: |
          # Bash is not friendly with newline characters, so make our own
          NL=$'\n'
          # Encode the PR title and body for passing as URL get parameters
          TITLE_ENC=$(jq -rn --arg x "[ROCm] ${PR_TITLE}" '$x|@uri')
          BODY_ENC=$(jq -rn --arg x "${PR_BODY}${NL}${NL}Created from: rocm/jax#${PR_NUMBER}" '$x|@uri')
          # Create a link that opens a new upstream PR form with the fields autofilled
          CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC"
          # Add a comment with the link to the PR
          COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK"
          gh pr comment "$PR_NUMBER" --repo rocm/jax --body "$COMMENT_BODY"
41+

.github/workflows/upstream-nightly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ on:
2222

2323
jobs:
2424
upstream-dev:
25-
runs-on: ubuntu-latest
25+
runs-on: ROCM-Ubuntu
2626
permissions:
2727
contents: read
2828
issues: write # for failed-build-issue

build/rocm/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
207207
### Simplified Build Script
208208

209209
For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.
210+

build/rocm/ci_build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def parse_args():
307307
"--test-cmd",
308308
default="./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh",
309309
)
310+
# NOTE(review): a second `testp.add_argument("--test-cmd", ...)` with the same default was
# added here. argparse raises "conflicting option string" when an option is registered twice
# on one parser, so the duplicate is dropped. This presumes the call ending just above is
# also on `testp` -- confirm against the full function.
310311

311312
ddp = subp.add_parser("dist_docker")
312313
ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")

build/rocm/upload_wheels.sh

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/bin/bash

# Usage: upload_wheels.sh <jax_home_directory> <version>
#
# Reads wheels from <jax_home_directory>/wheelhouse. The PyPI token is taken
# from the PYPI_API_TOKEN environment variable; a placeholder is used when it
# is unset.

# Both positional arguments are required.
if (( $# < 2 )); then
    echo "Usage: $0 <jax_home_directory> <version>"
    exit 1
fi

# Positional arguments and the wheel directory derived from them.
JAX_HOME="$1"
RELEASE_VERSION="$2"
WHEELHOUSE="${JAX_HOME}/wheelhouse"

# Each of these projects is uploaded to PyPI on its own.
PROJECTS=("jax_rocm60_pjrt" "jax_rocm60_plugin")

# Keep a token supplied by the environment; otherwise use the placeholder.
PYPI_API_TOKEN=${PYPI_API_TOKEN:-"pypi-replace_with_token"}

# Stop early unless both the JAX checkout and its wheelhouse exist.
[[ -d "$JAX_HOME" ]] || {
    echo "Error: The specified JAX_HOME directory does not exist: $JAX_HOME"
    exit 1
}
[[ -d "$WHEELHOUSE" ]] || {
    echo "Error: The wheelhouse directory does not exist: $WHEELHOUSE"
    exit 1
}
29+
30+
# Upload every wheel in $WHEELHOUSE that belongs to one project at $RELEASE_VERSION.
#
# Arguments:
#   $1 - project name as it appears at the start of the wheel filename.
# Globals read:
#   WHEELHOUSE, RELEASE_VERSION, PYPI_API_TOKEN
# Returns:
#   0 when no wheel matched or every upload succeeded; 1 if any twine call failed.
upload_and_release_project() {
    local project=$1
    local -a wheels=()
    local candidate wheel
    local status=0

    echo "Searching for wheels matching project: $project version: $RELEASE_VERSION..."
    # Use a glob rather than parsing `ls | grep`: it is safe for paths containing
    # whitespace, treats the dots in the version literally instead of as regex
    # wildcards, and only accepts names that actually end in ".whl".
    for candidate in "$WHEELHOUSE/${project}-${RELEASE_VERSION}"[.-]*.whl; do
        # An unmatched glob is left as the literal pattern; -f filters that out.
        if [[ -f "$candidate" ]]; then
            wheels+=("$candidate")
        fi
    done
    if [[ ${#wheels[@]} -eq 0 ]]; then
        echo "No wheels found for project: $project version: $RELEASE_VERSION. Skipping..."
        return
    fi
    # Print bare filenames (directory prefix stripped).
    echo "Found wheels for $project: ${wheels[*]##*/}"

    echo "Uploading wheels for $project version $RELEASE_VERSION to PyPI..."
    for wheel in "${wheels[@]}"; do
        # Credentials go through twine's environment variables so the token never
        # appears in the process argument list.
        TWINE_USERNAME="__token__" TWINE_PASSWORD="$PYPI_API_TOKEN" \
            twine upload --verbose --repository pypi --non-interactive "$wheel" || status=1
    done
    return "$status"
}
46+
47+
# Make sure an up-to-date twine is available before any upload starts.
python -m pip install --upgrade twine

# Upload the wheels of each project in PROJECTS, in the order listed.
for name in "${PROJECTS[@]}"; do
    upload_and_release_project "$name"
done

0 commit comments

Comments
 (0)