
Commit c4f9ebe

Support Array API in ArrayStore (#571)
## Description

Towards #570.

Passing `xp` and `device` into the archive and ArrayStore was suggested in data-apis/array-api-compat#342 and is based on what is done in scipy here:

* https://github.com/scipy/scipy/blob/v1.16.0/scipy/signal/windows/_windows.py#L953-L1009
* https://github.com/scipy/scipy/blob/4d3dcc103612a2edaec7069638b7f8d0d75cab8b/scipy/signal/windows/_windows.py#L44-L50

## TODO

- [x] Update ArrayStore
- [x] Add xp arg
- [x] Add device arg
- [x] Switch `updates` prop to be a simple list (so that we don't have to deal with arrays; it's only two elements)
- [x] Remove call to `aggregate` from numpy-groupies
- [x] Update pandas return type
- [x] Review docstrings
- [x] Handle dtypes -- how do we account for the dtypes in calling classes like GridArchive? -> See #577
- [x] Test ArrayStore with PyTorch
- [x] Add xp_and_device fixture to conftest
- [x] Add torch backend to conftest
- [x] Update remaining tests
- [x] Test on GPU

## Status

- [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md)
- [x] I have formatted my code using `yapf`
- [x] I have tested my code by running `pytest`
- [x] I have linted my code with `pylint`
- [x] I have added a one-line description of my change to the changelog in `HISTORY.md`
- [x] This PR is ready to go
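As a minimal sketch of the usage this PR enables (the field names, shapes, and capacity below are hypothetical, and PyTorch is assumed to be installed):

```python
# Sketch only: an ArrayStore backed by PyTorch tensors instead of NumPy arrays.
import torch

from ribs.archives._array_store import ArrayStore

store = ArrayStore(
    field_desc={
        "solution": ((3,), torch.float32),  # 3D solution vectors.
        "objective": ((), torch.float32),   # Scalar objectives.
    },
    capacity=10,
    xp=torch,      # Array namespace; defaults to numpy when omitted.
    device="cpu",  # Could be "cuda" when a GPU is available.
)

store.add(
    indices=torch.tensor([0, 1]),
    data={
        "solution": torch.zeros((2, 3)),
        "objective": torch.tensor([1.0, 2.0]),
    },
)

# Fields come back as torch tensors on the store's device.
occupied, data = store.retrieve(torch.tensor([0, 1]))
```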
1 parent f3cdec0 commit c4f9ebe

File tree

4 files changed: +493 -163 lines

HISTORY.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 #### API

-- Support array backends via Python array API Standard ({pr}`573`)
+- Support array backends via Python array API Standard ({pr}`573`, {pr}`571`)
 - **Backwards-incompatible:** Remove raw_dict methods from ArrayStore
   ({pr}`575`)

ribs/_utils.py

Lines changed: 29 additions & 1 deletion
@@ -1,7 +1,9 @@
 """Miscellaneous internal utilities."""
 import numbers

-import numpy as np
+import array_api_compat.numpy as np_compat
+import numpy as np  # TODO (#576): Remove import of np
+from array_api_compat import array_namespace


 def check_finite(x, name):
@@ -208,7 +210,33 @@ def validate_single(archive, data, none_objective_ok=False):
     return data


+# TODO (#576): Replace all calls to readonly with arr_readonly below.
 def readonly(arr):
     """Sets an array to be readonly."""
     arr.flags.writeable = False
     return arr
+
+
+def arr_readonly(arr):
+    """Sets an array to be readonly if possible. Intended to support arrays
+    across libraries; currently only supports numpy."""
+    if isinstance(arr, np_compat.ndarray):
+        readonly_arr = arr.view()
+        readonly_arr.flags.writeable = False
+        return readonly_arr
+    else:
+        return arr
+
+
+def xp_namespace(xp):
+    """Utility for retrieving a namespace compatible with the array API.
+
+    Expects to receive an argument like `torch` or `numpy`.
+
+    Adapted from scipy:
+    https://github.com/scipy/scipy/blob/4d3dcc103612a2edaec7069638b7f8d0d75cab8b/scipy/signal/windows/_windows.py#L44-L50
+
+    For more context, see:
+    https://github.com/data-apis/array-api-compat/issues/342
+    """
+    return np_compat if xp is None else array_namespace(xp.empty(0))
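A quick sketch of how `xp_namespace` behaves, assuming `torch` is installed:

```python
# Sketch: xp_namespace maps a raw array library (or None) onto an
# array-API-compatible namespace via array-api-compat.
import torch

from ribs._utils import xp_namespace

xp_np = xp_namespace(None)      # Default: array_api_compat.numpy.
xp_torch = xp_namespace(torch)  # Array-API-compatible wrapper around torch.

# Both namespaces expose the same array API surface.
a = xp_np.zeros(4, dtype=xp_np.float32)
b = xp_torch.zeros(4, dtype=xp_torch.float32)
```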

ribs/archives/_array_store.py

Lines changed: 104 additions & 39 deletions
@@ -4,10 +4,9 @@
 from enum import IntEnum
 from functools import cached_property

-import numpy as np
-from numpy_groupies import aggregate_nb as aggregate
+from array_api_compat import is_numpy_array, is_numpy_namespace, is_torch_array

-from ribs._utils import readonly
+from ribs._utils import arr_readonly, xp_namespace
 from ribs.archives._archive_data_frame import ArchiveDataFrame


@@ -36,7 +35,7 @@ def __next__(self):

         Raises RuntimeError if the store was modified.
         """
-        if not np.all(self.state == self.store._props["updates"]):
+        if self.state != self.store._props["updates"]:
             # This check should go before the StopIteration check because a call
             # to clear() would cause the len(self.store) to be 0 and thus
             # trigger StopIteration.
@@ -61,8 +60,8 @@ class ArrayStore:
     """Maintains a set of arrays that share a common dimension.

     The ArrayStore consists of several *fields* of data that are manipulated
-    simultaneously via batch operations. Each field is a NumPy array with a
-    dimension of ``(capacity, ...)`` and can be of any type.
+    simultaneously via batch operations. Each field is an array with a dimension
+    of ``(capacity, ...)`` and can be of any type.

     Since the arrays all share a common first dimension, they also share a
     common index. For instance, if we :meth:`retrieve` the data at indices ``[0,
@@ -77,6 +76,12 @@ class ArrayStore:
     The ArrayStore supports several further operations, such as an :meth:`add`
     method that inserts data into the ArrayStore.

+    By default, the arrays in the ArrayStore are NumPy arrays. However, through
+    support for the `Python array API standard
+    <https://data-apis.org/array-api/latest/>`_, it is possible to use arrays
+    from other libraries like PyTorch by passing in arguments for ``xp`` and
+    ``device``.
+
     Args:
         field_desc (dict): Description of fields in the array store. The
             description is a dict mapping from a str to a tuple of ``(shape,
@@ -86,6 +91,10 @@ class ArrayStore:
             ``(capacity, 10)``. Note that field names must be valid Python
             identifiers.
         capacity (int): Total possible entries in the store.
+        xp (array_namespace): Optional array namespace. Should be compatible
+            with the array API standard, or supported by array-api-compat.
+            Defaults to ``numpy``.
+        device (device): Device for arrays.

     Attributes:
         _props (dict): Properties that are common to every ArrayStore.
@@ -97,7 +106,7 @@ class ArrayStore:
             * "occupied_list": Array of size ``(capacity,)`` listing all
               occupied indices in the store. Only the first ``n_occupied``
               elements will be valid.
-            * "updates": Int array recording number of calls to functions that
+            * "updates": Int list recording number of calls to functions that
               modified the store.

         _fields (dict): Holds all the arrays with their data.
@@ -109,13 +118,22 @@ class ArrayStore:
             valid Python identifier.
     """

-    def __init__(self, field_desc, capacity):
+    def __init__(self, field_desc, capacity, xp=None, device=None):
+        self._xp = xp_namespace(xp)
+        self._device = device
+
         self._props = {
-            "capacity": capacity,
-            "occupied": np.zeros(capacity, dtype=bool),
-            "n_occupied": 0,
-            "occupied_list": np.empty(capacity, dtype=np.int32),
-            "updates": np.array([0, 0]),
+            "capacity":
+                capacity,
+            "occupied":
+                self._xp.zeros(capacity, dtype=bool, device=self._device),
+            "n_occupied":
+                0,
+            "occupied_list":
+                self._xp.empty(capacity,
+                               dtype=self._xp.int32,
+                               device=self._device),
+            "updates": [0, 0],
         }

         self._fields = {}
@@ -130,7 +148,9 @@ def __init__(self, field_desc, capacity):
                 field_shape = (field_shape,)

             array_shape = (capacity,) + tuple(field_shape)
-            self._fields[name] = np.empty(array_shape, dtype)
+            self._fields[name] = self._xp.empty(array_shape,
+                                                dtype=dtype,
+                                                device=self._device)

     def __len__(self):
         """Number of occupied indices in the store, i.e., number of indices that
@@ -163,15 +183,14 @@ def capacity(self):

     @property
     def occupied(self):
-        """numpy.ndarray: Boolean array of size ``(capacity,)`` indicating
-        whether each index has a data entry."""
-        return readonly(self._props["occupied"].view())
+        """array: Boolean array of size ``(capacity,)`` indicating whether each
+        index has a data entry."""
+        return arr_readonly(self._props["occupied"])

     @property
     def occupied_list(self):
-        """numpy.ndarray: int32 array listing all occupied indices in the
-        store."""
-        return readonly(
+        """array: int32 array listing all occupied indices in the store."""
+        return arr_readonly(
             self._props["occupied_list"][:self._props["n_occupied"]])

     @cached_property
@@ -211,10 +230,14 @@ def dtypes(self):
                 "measures": np.float32,
             }
         """
-        # Calling `.type` retrieves the numpy scalar type, which is callable:
-        # - https://numpy.org/doc/stable/reference/arrays.scalars.html
-        # - https://numpy.org/doc/stable/reference/arrays.dtypes.html
-        return {name: arr.dtype.type for name, arr in self._fields.items()}
+        if is_numpy_namespace(self._xp):
+            # TODO (#577): In NumPy, we currently want the scalar type (i.e.,
+            # arr.dtype.type rather than arr.dtype), which is callable.
+            # Ultimately, this should be switched to just be the dtype to be
+            # compatible across array libraries.
+            return {name: arr.dtype.type for name, arr in self._fields.items()}
+        else:
+            return {name: arr.dtype for name, arr in self._fields.items()}

     @cached_property
     def dtypes_with_index(self):
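A small sketch of what the two branches above imply for callers (hypothetical stores; `torch` assumed installed):

```python
# Sketch: with the default NumPy backend, dtypes values are callable scalar
# types (np.float32); with a torch backend, they are torch dtypes.
import numpy as np
import torch

from ribs.archives._array_store import ArrayStore

np_store = ArrayStore({"objective": ((), np.float32)}, capacity=4)
assert np_store.dtypes["objective"] is np.float32  # Callable scalar type.

torch_store = ArrayStore({"objective": ((), torch.float32)}, capacity=4,
                         xp=torch)
assert torch_store.dtypes["objective"] == torch.float32  # Library dtype.
```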
@@ -230,7 +253,7 @@ def dtypes_with_index(self):
                 "index": np.int32,
             }
         """
-        return self.dtypes | {"index": np.int32}
+        return self.dtypes | {"index": self._xp.int32}

     @cached_property
     def field_list(self):
@@ -261,15 +284,29 @@ def field_list_with_index(self):
         """
         return list(self._fields) + ["index"]

+    @staticmethod
+    def _convert_to_numpy(arr):
+        """If needed, converts the given array to a numpy array for the pandas
+        return type in `retrieve`."""
+        if is_numpy_array(arr):
+            return arr
+        elif is_torch_array(arr):
+            return arr.cpu().detach().numpy()
+        else:
+            raise NotImplementedError(
+                "The pandas return type is currently only supported "
+                "with numpy and torch arrays.")
+
     def retrieve(self, indices, fields=None, return_type="dict"):
         """Collects data at the given indices.

         Args:
             indices (array-like): List of indices at which to collect data.
             fields (str or array-like of str): List of fields to include. By
                 default, all fields will be included, with an additional "index"
-                as the last field ("index" can also be placed anywhere in this
-                list). This can also be a single str indicating a field name.
+                as the last field. The "index" field can also be added anywhere
+                in this list of fields. This argument can also be a single str
+                indicating a field name.
             return_type (str): Type of data to return. See the ``data`` returned
                 below. Ignored if ``fields`` is a str.
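A rough sketch of what `_convert_to_numpy` does for the pandas return type (torch assumed installed; the tensor here is illustrative):

```python
# Sketch: torch tensors (possibly on GPU) are detached, moved to CPU, and
# converted to numpy before they are placed into the ArchiveDataFrame.
import numpy as np
import torch

t = torch.arange(4, dtype=torch.float32)  # Would be on "cuda" on a GPU run.
as_np = t.cpu().detach().numpy()          # Same conversion as _convert_to_numpy.
assert isinstance(as_np, np.ndarray)
```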
@@ -346,6 +383,10 @@ def retrieve(self, indices, fields=None, return_type="dict"):
                 Like the other return types, the columns can be adjusted with
                 the ``fields`` parameter.

+                .. note:: This return type will require copying all fields in
+                    the ArrayStore into NumPy arrays, if they are not already
+                    NumPy arrays.
+
         All data returned by this method will be a copy, i.e., the data will
         not update as the store changes.
@@ -354,8 +395,12 @@ def retrieve(self, indices, fields=None, return_type="dict"):
             ValueError: Invalid return_type provided.
         """
         single_field = isinstance(fields, str)
-        indices = np.asarray(indices, dtype=np.int32)
-        occupied = self._props["occupied"][indices]  # Induces copy.
+        indices = self._xp.asarray(indices,
+                                   dtype=self._xp.int32,
+                                   device=self._device)
+
+        # Induces copy (in numpy, at least).
+        occupied = self._props["occupied"][indices]

         if single_field:
             data = None
@@ -374,10 +419,10 @@ def retrieve(self, indices, fields=None, return_type="dict"):
         for name in fields:
             # Collect array data.
             #
-            # Note that fancy indexing with indices already creates a copy, so
-            # only `indices` needs to be copied explicitly.
+            # Note that fancy indexing with indices already creates a copy (in
+            # numpy, at least), so only `indices` needs to be copied explicitly.
             if name == "index":
-                arr = np.copy(indices)
+                arr = self._xp.asarray(indices, copy=True)
             elif name in self._fields:
                 arr = self._fields[name][indices]  # Induces copy.
             else:
@@ -391,6 +436,8 @@ def retrieve(self, indices, fields=None, return_type="dict"):
             elif return_type == "tuple":
                 data.append(arr)
             elif return_type == "pandas":
+                arr = self._convert_to_numpy(arr)
+
                 if len(arr.shape) == 1:  # Scalar entries.
                     data[name] = arr
                 elif len(arr.shape) == 2:  # 1D array entries.
@@ -405,6 +452,8 @@ def retrieve(self, indices, fields=None, return_type="dict"):
         if return_type == "tuple":
             data = tuple(data)
         elif return_type == "pandas":
+            occupied = self._convert_to_numpy(occupied)
+
             # Data above are already copied, so no need to copy again.
             data = ArchiveDataFrame(data, copy=False)

@@ -471,8 +520,16 @@ def add(self, indices, data):
                 "This can also occur if the archive and result_archive have "
                 "different extra_fields.")

+        # Determine the unique indices. These operations are preferred over
+        # `xp.unique_values(indices)` because they operate in linear time, while
+        # unique_values usually sorts the input.
+        indices_occupied = self._xp.zeros(self.capacity,
+                                          dtype=bool,
+                                          device=self._device)
+        indices_occupied[indices] = True
+        unique_indices = self._xp.nonzero(indices_occupied)[0]
+
         # Update occupancy data.
-        unique_indices = np.where(aggregate(indices, 1, func="len") != 0)[0]
         cur_occupied = self._props["occupied"][unique_indices]
         new_indices = unique_indices[~cur_occupied]
         n_occupied = self._props["n_occupied"]
@@ -483,16 +540,18 @@ def add(self, indices, data):

         # Insert into the ArrayStore. Note that we do not assume indices are
         # unique. Hence, when updating occupancy data above, we computed the
-        # unique indices. In contrast, here we let NumPy's default behavior
+        # unique indices. In contrast, here we let the array's default behavior
         # handle duplicate indices.
         for name, arr in self._fields.items():
-            arr[indices] = data[name]
+            arr[indices] = self._xp.asarray(data[name],
+                                            dtype=arr.dtype,
+                                            device=self._device)

     def clear(self):
         """Removes all entries from the store."""
         self._props["updates"][Update.CLEAR] += 1
         self._props["n_occupied"] = 0  # Effectively clears occupied_list too.
-        self._props["occupied"].fill(False)
+        self._props["occupied"][:] = False

     def resize(self, capacity):
         """Resizes the store to the given capacity.
@@ -512,14 +571,20 @@ def resize(self, capacity):
         self._props["capacity"] = capacity

         cur_occupied = self._props["occupied"]
-        self._props["occupied"] = np.zeros(capacity, dtype=bool)
+        self._props["occupied"] = self._xp.zeros(capacity,
+                                                 dtype=bool,
+                                                 device=self._device)
         self._props["occupied"][:cur_capacity] = cur_occupied

         cur_occupied_list = self._props["occupied_list"]
-        self._props["occupied_list"] = np.empty(capacity, dtype=np.int32)
+        self._props["occupied_list"] = self._xp.empty(capacity,
+                                                      dtype=self._xp.int32,
+                                                      device=self._device)
         self._props["occupied_list"][:cur_capacity] = cur_occupied_list

         for name, cur_arr in self._fields.items():
             new_shape = (capacity,) + cur_arr.shape[1:]
-            self._fields[name] = np.empty(new_shape, cur_arr.dtype)
+            self._fields[name] = self._xp.empty(new_shape,
+                                                dtype=cur_arr.dtype,
+                                                device=self._device)
             self._fields[name][:cur_capacity] = cur_arr
