1- name : CI
1+ name : ROCm CPU CI
22
33# We test all supported Python versions as follows:
44# - 3.10 : Documentation build
1111 # but only for the main branch
1212 push :
1313 branches :
14- - main
14+ - rocm- main
1515 pull_request :
1616 branches :
17- - main
17+ - rocm- main
1818
1919permissions :
2020 contents : read # to fetch code
4242 - run : pre-commit run --show-diff-on-failure --color=always --all-files
4343
4444 build :
45- # Don't execute in fork due to runner type
46- if : github.repository == 'jax-ml/jax'
4745 name : " build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
48- runs-on : linux-x86-n2-32
49- container :
50- image : index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
46+ runs-on : ROCM-Ubuntu
5147 timeout-minutes : 60
5248 strategy :
5349 matrix :
6561 num_generated_cases : 1
6662 steps :
6763 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
68- - name : Image Setup
69- run : |
70- apt update
71- apt install -y libssl-dev
7264 - name : Set up Python ${{ matrix.python-version }}
7365 uses : actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
7466 with :
@@ -95,12 +87,12 @@ jobs:
9587 echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
9688 echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
9789 echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
98- pytest -n auto --tb=short --maxfail=20 tests examples
90+ pytest -n 4 --tb=short --maxfail=20 tests examples
9991
10092
10193 documentation :
10294 name : Documentation - test code snippets
103- runs-on : ubuntu-latest
95+ runs-on : ROCM-Ubuntu
10496 timeout-minutes : 10
10597 strategy :
10698 matrix :
@@ -128,19 +120,13 @@ jobs:
128120
129121 documentation_render :
130122 name : Documentation - render documentation
131- runs-on : linux-x86-n2-16
132- container :
133- image : index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
134- timeout-minutes : 10
123+ runs-on : ubuntu-latest
124+ timeout-minutes : 20
135125 strategy :
136126 matrix :
137127 python-version : ['3.10']
138128 steps :
139129 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
140- - name : Image Setup
141- run : |
142- apt update
143- apt install -y libssl-dev libsqlite3-dev
144130 - name : Set up Python ${{ matrix.python-version }}
145131 uses : actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
146132 with :
@@ -194,9 +180,7 @@ jobs:
194180
195181 ffi :
196182 name : FFI example
197- runs-on : linux-x86-g2-16-l4-1gpu
198- container :
199- image : index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
183+ runs-on : ROCM-Ubuntu
200184 timeout-minutes : 30
201185 steps :
202186 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -207,7 +191,8 @@ jobs:
207191 - name : Install JAX
208192 run : |
209193 pip install uv~=0.5.30
210- uv pip install --system .[cuda12]
194+ pip install uv
195+ uv pip install --system .
211196 - name : Build and install example project
212197 run : uv pip install --system ./examples/ffi[test]
213198 env :
@@ -216,10 +201,11 @@ jobs:
216201 # a different toolchain. GCC is the default compiler on the
217202 # 'ubuntu-latest' runner, but we still set this explicitly just to be
218203 # clear.
219- CMAKE_ARGS : -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
204+ CMAKE_ARGS : -DCMAKE_CXX_COMPILER=g++ # -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
220205 - name : Run CPU tests
221206 run : python -m pytest examples/ffi/tests
222207 env :
223208 JAX_PLATFORM_NAME : cpu
224209 - name : Run GPU tests
225210 run : python -m pytest examples/ffi/tests
211+
0 commit comments