1- name : CI
1+ name : ROCm CPU CI
22
33# We test all supported Python versions as follows:
44# - 3.10 : Documentation build
1111 # but only for the main branch
1212 push :
1313 branches :
14- - main
14+ - rocm- main
1515 pull_request :
1616 branches :
17- - main
17+ - rocm- main
1818
1919permissions :
2020 contents : read # to fetch code
4242 - run : pre-commit run --show-diff-on-failure --color=always --all-files
4343
4444 build :
45- # Don't execute in fork due to runner type
46- if : github.repository == 'jax-ml/jax'
4745 name : " build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
48- runs-on : linux-x86-n2-32
49- container :
50- image : index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
46+ runs-on : ROCM-Ubuntu
5147 timeout-minutes : 60
5248 strategy :
5349 matrix :
6561 num_generated_cases : 1
6662 steps :
6763 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
68- - name : Image Setup
69- run : |
70- apt update
71- apt install -y libssl-dev
7264 - name : Set up Python ${{ matrix.python-version }}
7365 uses : actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
7466 with :
@@ -95,12 +87,12 @@ jobs:
9587 echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
9688 echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
9789 echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
98- pytest -n auto --tb=short --maxfail=20 tests examples
90+ pytest -n 4 --tb=short --maxfail=20 tests examples
9991
10092
10193 documentation :
10294 name : Documentation - test code snippets
103- runs-on : ubuntu-latest
95+ runs-on : ROCM-Ubuntu
10496 timeout-minutes : 10
10597 strategy :
10698 matrix :
@@ -128,19 +120,13 @@ jobs:
128120
129121 documentation_render :
130122 name : Documentation - render documentation
131- runs-on : linux-x86-n2-16
132- container :
133- image : index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
134- timeout-minutes : 10
123+ runs-on : ubuntu-latest
124+ timeout-minutes : 20
135125 strategy :
136126 matrix :
137127 python-version : ['3.10']
138128 steps :
139129 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
140- - name : Image Setup
141- run : |
142- apt update
143- apt install -y libssl-dev libsqlite3-dev
144130 - name : Set up Python ${{ matrix.python-version }}
145131 uses : actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
146132 with :
@@ -193,9 +179,7 @@ jobs:
193179
194180 ffi :
195181 name : FFI example
196- runs-on : linux-x86-g2-16-l4-1gpu
197- container :
198- image : index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
182+ runs-on : ROCM-Ubuntu
199183 timeout-minutes : 30
200184 steps :
201185 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -206,7 +190,7 @@ jobs:
206190 - name : Install JAX
207191 run : |
208192 pip install uv
209- uv pip install --system .[cuda12]
193+ uv pip install --system .
210194 - name : Build and install example project
211195 run : uv pip install --system ./examples/ffi[test]
212196 env :
@@ -215,10 +199,11 @@ jobs:
215199 # a different toolchain. GCC is the default compiler on the
216200 # 'ubuntu-latest' runner, but we still set this explicitly just to be
217201 # clear.
218- CMAKE_ARGS : -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
202+ CMAKE_ARGS : -DCMAKE_CXX_COMPILER=g++ # -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
219203 - name : Run CPU tests
220204 run : python -m pytest examples/ffi/tests
221205 env :
222206 JAX_PLATFORM_NAME : cpu
223207 - name : Run GPU tests
224208 run : python -m pytest examples/ffi/tests
209+
0 commit comments