Skip to content

Commit 73f3518

Browse files
authored
Port PR CI workflow from rocm-main (#312)
1 parent 0074fed commit 73f3518

File tree

2 files changed

+79
-4
lines changed

2 files changed

+79
-4
lines changed

.github/workflows/rocm-ci.yml

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
name: ROCm GPU CI

on:
  # Trigger the workflow on push or pull request,
  # but only for the rocm-main and rocm-jaxlib-v* branches
  push:
    branches:
      - rocm-main
      - 'rocm-jaxlib-v*'
  pull_request:
    branches:
      - rocm-main
      - 'rocm-jaxlib-v*'

# One active run per workflow and PR head branch (or pushed ref);
# a newer run cancels the one that is still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  # Builds the JAX wheels and a test image with build/rocm/ci_build, uploads
  # the wheels as an artifact, then runs a pytest command in the test image.
  build-jax-in-docker:
    runs-on: mi-250
    strategy:
      matrix:
        # Versions are quoted so that "3.10" is not parsed as the float 3.1
        python: ["3.10", "3.11", "3.12"]
        rocm: ["6.3.3"]
    env:
      BASE_IMAGE: "ubuntu:22.04"
      TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
      PYTHON_VERSION: ${{ matrix.python }}
      ROCM_VERSION: ${{ matrix.rocm }}
      # Per-run checkout directory; older ones are removed by the first step
      WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
    steps:
      - name: Clean up old runs
        run: |
          ls
          # Make sure that we own all of the files so that we have permissions to delete them.
          # --rm removes the helper container on exit so stopped containers do not pile up.
          docker run --rm -v "./:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
          # Remove any old work directories from this machine
          rm -rf workdir_*
          ls
      - name: Print system info
        run: |
          whoami
          printenv
          df -h
          rocm-smi
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          path: ${{ env.WORKSPACE_DIR }}
      - name: Build JAX
        run: |
          pushd "$WORKSPACE_DIR"
          python3 build/rocm/ci_build \
            --rocm-version "$ROCM_VERSION" \
            --base-docker "$BASE_IMAGE" \
            --python-versions "$PYTHON_VERSION" \
            --compiler=clang \
            dist_docker \
            --image-tag "$TEST_IMAGE"
      - name: Archive jax wheels
        uses: actions/upload-artifact@v4
        with:
          name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
          # The checkout and the build both live under WORKSPACE_DIR, so the
          # wheels are looked up there instead of at the runner workspace root.
          path: ${{ env.WORKSPACE_DIR }}/dist/*.whl
      - name: Run tests
        env:
          GPU_COUNT: "8"
          GFX: "gfx90a"
        run: |
          cd "$WORKSPACE_DIR"
          python3 build/rocm/ci_build test "$TEST_IMAGE" --test-cmd "pytest tests/core_test.py"
72+

build/rocm/ci_build

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def dist_docker(
193193
subprocess.check_call(cmd)
194194

195195

196-
def test(image_name):
196+
def test(image_name, test_cmd):
197197
"""Run unit tests like CI would inside a JAX image."""
198198

199199
gpu_args = [
@@ -211,10 +211,12 @@ def test(image_name):
211211
cmd = [
212212
"docker",
213213
"run",
214-
"-it",
215214
"--rm",
216215
]
217216

217+
if os.isatty(sys.stdout.fileno()):
218+
cmd.append("-it")
219+
218220
# NOTE(mrodden): we need jax source dir for the unit test code only,
219221
# JAX and jaxlib are already installed from wheels
220222
mounts = [
@@ -225,7 +227,7 @@ def test(image_name):
225227
cmd.extend(mounts)
226228
cmd.extend(gpu_args)
227229

228-
container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh"
230+
container_cmd = "cd /jax && " + test_cmd
229231
cmd.append(image_name)
230232
cmd.extend(
231233
[
@@ -288,6 +290,7 @@ def parse_args():
288290

289291
testp = subp.add_parser("test")
290292
testp.add_argument("image_name")
293+
testp.add_argument("--test-cmd", default="./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh")
291294

292295
ddp = subp.add_parser("dist_docker")
293296
ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
@@ -311,7 +314,7 @@ def main():
311314
)
312315

313316
elif args.action == "test":
314-
test(args.image_name)
317+
test(args.image_name, args.test_cmd)
315318

316319
elif args.action == "dist_docker":
317320
dist_wheels(

0 commit comments

Comments
 (0)