Skip to content

Commit 193707d

Browse files
authored
Start Array API infrastructure (#573)
## Description <!-- Provide a brief description of the PR's purpose here. --> Towards #570 ## TODO <!-- Notable points that this PR has either accomplished or will accomplish. --> - [x] Add `array-api-compat` to deps -- no version specified because it seems like a fairly flexible library - [x] Add instructions for installing PyTorch -- it is now marked as experimental in the README - [x] Add CI for testing with PyTorch -- it is added as a separate test in `ci.yaml` that copies the existing tests -- eventually this should be merged into the main tests (#574) - [x] Set up testing for Array API libraries with `xp_and_device` fixture in `tests/conftest.py` - [x] Update testing instructions in `tests/README.md` ## Status - [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md) - [x] I have formatted my code using `yapf` - [x] I have tested my code by running `pytest` - [x] I have linted my code with `pylint` - [x] I have added a one-line description of my change to the changelog in `HISTORY.md` - [x] This PR is ready to go
1 parent 2618657 commit 193707d

File tree

6 files changed

+93
-12
lines changed

6 files changed

+93
-12
lines changed

.github/workflows/testing.yml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,37 @@ jobs:
6161
run: pip install pymoo
6262
- name: Test pymoo extra
6363
run: pytest tests/emitters_pymoo tests/schedulers_pymoo
64+
test_array_api:
65+
strategy:
66+
max-parallel: 12 # All in parallel.
67+
matrix:
68+
os: [ubuntu-latest, macos-latest, windows-latest]
69+
python-version: ["3.12"]
70+
runs-on: ${{ matrix.os }}
71+
steps:
72+
- uses: actions/checkout@v4
73+
- uses: conda-incubator/setup-miniconda@v3
74+
with:
75+
python-version: ${{ matrix.python-version }}
76+
- name: Install core deps
77+
run: pip install .[dev]
78+
- name: Install torch
79+
run: pip install torch
80+
- name: Test core
81+
run: >
82+
pytest tests/archives tests/emitters tests/schedulers
83+
- name: Install visualize dep
84+
run: pip install .[visualize]
85+
- name: Test visualize extra
86+
run: pytest tests/visualize
87+
- name: Install cma
88+
run: pip install cma
89+
- name: Test cma extra
90+
run: pytest tests/emitters_pycma
91+
- name: Install pymoo
92+
run: pip install pymoo
93+
- name: Test pymoo extra
94+
run: pytest tests/emitters_pymoo tests/schedulers_pymoo
6495
visualize_qdax:
6596
runs-on: ubuntu-latest
6697
steps:

HISTORY.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# History
22

3+
## (Forthcoming)
4+
5+
### Changelog
6+
7+
#### API
8+
9+
- Support array backends via Python array API Standard ({pr}`573`)
10+
311
## 0.8.1
412

513
This release makes some minor updates to 0.8.0.

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ in pyribs operates as follows:
9595

9696
## Installation
9797

98-
pyribs supports Python 3.9 and above. The vast majority of users can install
98+
Pyribs supports Python 3.9 and above. The vast majority of users can install
9999
pyribs by running:
100100

101101
```bash
@@ -123,6 +123,12 @@ python -c "import ribs; print(ribs.__version__)"
123123

124124
You should see a version number in the output.
125125

126+
**Experimental:** Pyribs is experimenting with adding support for running QD
127+
algorithms in PyTorch via the
128+
[Python array API standard](https://data-apis.org/array-api/latest/). To enable
129+
this functionality, [install PyTorch](https://pytorch.org), such as with
130+
`pip install torch`.
131+
126132
## Usage
127133

128134
Here we show an example application of CMA-ME in pyribs. To initialize the

pyproject.toml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,14 @@ classifiers = [
2929
]
3030
requires-python = ">=3.9.0"
3131
dependencies = [
32-
# numpy>=1.17.0 is when default_rng becomes available;
33-
# scikit-learn 1.1.0 requires numpy 1.17.3+
34-
"numpy>=1.17.3",
35-
"numpy_groupies>=0.9.16", # Supports Python 3.7 and up.
32+
"array_api_compat",
3633
"numba>=0.51.0",
34+
"numpy>=1.22.0", # numpy>=1.22.0 is when Array API gains support.
35+
"numpy_groupies>=0.9.16", # Supports Python 3.7 and up.
3736
"pandas>=1.0.0",
37+
"scikit-learn>=1.1.0",
38+
"scipy>=1.7.0",
3839
"sortedcontainers>=2.0.0", # Primarily used in SlidingBoundariesArchive.
39-
"scikit-learn>=1.1.0", # Primarily used in CVTArchive.
40-
"scipy>=1.7.0", # Primarily used in CVTArchive.
4140
"threadpoolctl>=3.0.0",
4241
]
4342

tests/README.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@ For information on running tests, see [CONTRIBUTING.md](../CONTRIBUTING.md).
1010
## External Libraries
1111

1212
Some features of pyribs require external libraries that are optional and thus
13-
not specified in the default installation command. We separate these tests into
14-
separate directories:
15-
16-
- `visualize_qdax/` tests visualization of QDax components
17-
- `emitters_pycma/` holds emitter tests that require pycma
13+
not specified in the default installation command. We place these tests into
14+
separate directories, such as `visualize_qdax/` and `emitters_pycma/`.
1815

1916
## Additional Tests
2017

2118
This directory also contains:
2219

2320
- `examples.sh`: checks that the examples work end-to-end
2421
- `tutorials.sh`: checks that the tutorials work end-to-end
22+
23+
## Array API
24+
25+
To write tests for components that feature the Array API, use the
26+
`xp_and_device` fixture to receive a tuple with the array namespace `xp` and the
27+
device `device`. `xp_and_device` is implemented in `tests/conftest.py`.

tests/conftest.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Shared functionality for all tests."""
2+
import numpy as np
3+
import pytest
4+
from array_api_compat import array_namespace
5+
6+
# Array API backend handling. Adapted from scipy:
7+
# https://github.com/scipy/scipy/blob/888ca356eda34481e0e32b1be48c1262077d79a7/scipy/conftest.py#L283
8+
xp_available_backends = [
9+
pytest.param(
10+
(np, None), # `None` should default to cpu for numpy.
11+
id="numpy-cpu",
12+
),
13+
]
14+
15+
try:
16+
import torch
17+
18+
xp_available_backends.append(
19+
pytest.param((torch, torch.device("cpu")), id="torch-cpu"))
20+
21+
if torch.cuda.is_available():
22+
xp_available_backends.append(
23+
pytest.param((torch, torch.device("cuda")), id="torch-cuda"))
24+
except ImportError:
25+
pass
26+
27+
28+
@pytest.fixture(params=xp_available_backends)
29+
def xp_and_device(request):
30+
"""Run the test that uses this fixture on each available array API library
31+
and device."""
32+
xp, device = request.param
33+
xp = array_namespace(xp.empty(0))
34+
return xp, device

0 commit comments

Comments
 (0)