4
4
from enum import IntEnum
5
5
from functools import cached_property
6
6
7
- import numpy as np
8
- from numpy_groupies import aggregate_nb as aggregate
7
+ from array_api_compat import is_numpy_array , is_numpy_namespace , is_torch_array
9
8
10
- from ribs ._utils import readonly
9
+ from ribs ._utils import arr_readonly , xp_namespace
11
10
from ribs .archives ._archive_data_frame import ArchiveDataFrame
12
11
13
12
@@ -36,7 +35,7 @@ def __next__(self):
36
35
37
36
Raises RuntimeError if the store was modified.
38
37
"""
39
- if not np . all ( self .state == self .store ._props ["updates" ]) :
38
+ if self .state != self .store ._props ["updates" ]:
40
39
# This check should go before the StopIteration check because a call
41
40
# to clear() would cause the len(self.store) to be 0 and thus
42
41
# trigger StopIteration.
@@ -61,8 +60,8 @@ class ArrayStore:
61
60
"""Maintains a set of arrays that share a common dimension.
62
61
63
62
The ArrayStore consists of several *fields* of data that are manipulated
64
- simultaneously via batch operations. Each field is a NumPy array with a
65
- dimension of ``(capacity, ...)`` and can be of any type.
63
+ simultaneously via batch operations. Each field is an array with a dimension
64
+ of ``(capacity, ...)`` and can be of any type.
66
65
67
66
Since the arrays all share a common first dimension, they also share a
68
67
common index. For instance, if we :meth:`retrieve` the data at indices ``[0,
@@ -77,6 +76,12 @@ class ArrayStore:
77
76
The ArrayStore supports several further operations, such as an :meth:`add`
78
77
method that inserts data into the ArrayStore.
79
78
79
+ By default, the arrays in the ArrayStore are NumPy arrays. However, through
80
+ support for the `Python array API standard
81
+ <https://data-apis.org/array-api/latest/>`_, it is possible to use arrays
82
+ from other libraries like PyTorch by passing in arguments for ``xp`` and
83
+ ``device``.
84
+
80
85
Args:
81
86
field_desc (dict): Description of fields in the array store. The
82
87
description is a dict mapping from a str to a tuple of ``(shape,
@@ -86,6 +91,10 @@ class ArrayStore:
86
91
``(capacity, 10)``. Note that field names must be valid Python
87
92
identifiers.
88
93
capacity (int): Total possible entries in the store.
94
+ xp (array_namespace): Optional array namespace. Should be compatible
95
+ with the array API standard, or supported by array-api-compat.
96
+ Defaults to ``numpy``.
97
+ device (device): Device for arrays.
89
98
90
99
Attributes:
91
100
_props (dict): Properties that are common to every ArrayStore.
@@ -97,7 +106,7 @@ class ArrayStore:
97
106
* "occupied_list": Array of size ``(capacity,)`` listing all
98
107
occupied indices in the store. Only the first ``n_occupied``
99
108
elements will be valid.
100
- * "updates": Int array recording number of calls to functions that
109
+ * "updates": Int list recording number of calls to functions that
101
110
modified the store.
102
111
103
112
_fields (dict): Holds all the arrays with their data.
@@ -109,13 +118,22 @@ class ArrayStore:
109
118
valid Python identifier.
110
119
"""
111
120
112
- def __init__ (self , field_desc , capacity ):
121
+ def __init__ (self , field_desc , capacity , xp = None , device = None ):
122
+ self ._xp = xp_namespace (xp )
123
+ self ._device = device
124
+
113
125
self ._props = {
114
- "capacity" : capacity ,
115
- "occupied" : np .zeros (capacity , dtype = bool ),
116
- "n_occupied" : 0 ,
117
- "occupied_list" : np .empty (capacity , dtype = np .int32 ),
118
- "updates" : np .array ([0 , 0 ]),
126
+ "capacity" :
127
+ capacity ,
128
+ "occupied" :
129
+ self ._xp .zeros (capacity , dtype = bool , device = self ._device ),
130
+ "n_occupied" :
131
+ 0 ,
132
+ "occupied_list" :
133
+ self ._xp .empty (capacity ,
134
+ dtype = self ._xp .int32 ,
135
+ device = self ._device ),
136
+ "updates" : [0 , 0 ],
119
137
}
120
138
121
139
self ._fields = {}
@@ -130,7 +148,9 @@ def __init__(self, field_desc, capacity):
130
148
field_shape = (field_shape ,)
131
149
132
150
array_shape = (capacity ,) + tuple (field_shape )
133
- self ._fields [name ] = np .empty (array_shape , dtype )
151
+ self ._fields [name ] = self ._xp .empty (array_shape ,
152
+ dtype = dtype ,
153
+ device = self ._device )
134
154
135
155
def __len__ (self ):
136
156
"""Number of occupied indices in the store, i.e., number of indices that
@@ -163,15 +183,14 @@ def capacity(self):
163
183
164
184
@property
165
185
def occupied (self ):
166
- """numpy.ndarray : Boolean array of size ``(capacity,)`` indicating
167
- whether each index has a data entry."""
168
- return readonly (self ._props ["occupied" ]. view () )
186
+ """array : Boolean array of size ``(capacity,)`` indicating whether each
187
+ index has a data entry."""
188
+ return arr_readonly (self ._props ["occupied" ])
169
189
170
190
@property
171
191
def occupied_list (self ):
172
- """numpy.ndarray: int32 array listing all occupied indices in the
173
- store."""
174
- return readonly (
192
+ """array: int32 array listing all occupied indices in the store."""
193
+ return arr_readonly (
175
194
self ._props ["occupied_list" ][:self ._props ["n_occupied" ]])
176
195
177
196
@cached_property
@@ -211,10 +230,14 @@ def dtypes(self):
211
230
"measures": np.float32,
212
231
}
213
232
"""
214
- # Calling `.type` retrieves the numpy scalar type, which is callable:
215
- # - https://numpy.org/doc/stable/reference/arrays.scalars.html
216
- # - https://numpy.org/doc/stable/reference/arrays.dtypes.html
217
- return {name : arr .dtype .type for name , arr in self ._fields .items ()}
233
+ if is_numpy_namespace (self ._xp ):
234
+ # TODO (#577): In NumPy, we currently want the scalar type (i.e.,
235
+ # arr.dtype.type rather than arr.dtype), which is callable.
236
+ # Ultimately, this should be switched to just be the dtype to be
237
+ # compatible across array libraries.
238
+ return {name : arr .dtype .type for name , arr in self ._fields .items ()}
239
+ else :
240
+ return {name : arr .dtype for name , arr in self ._fields .items ()}
218
241
219
242
@cached_property
220
243
def dtypes_with_index (self ):
@@ -230,7 +253,7 @@ def dtypes_with_index(self):
230
253
"index": np.int32,
231
254
}
232
255
"""
233
- return self .dtypes | {"index" : np .int32 }
256
+ return self .dtypes | {"index" : self . _xp .int32 }
234
257
235
258
@cached_property
236
259
def field_list (self ):
@@ -261,15 +284,29 @@ def field_list_with_index(self):
261
284
"""
262
285
return list (self ._fields ) + ["index" ]
263
286
287
+ @staticmethod
288
+ def _convert_to_numpy (arr ):
289
+ """If needed, converts the given array to a numpy array for the pandas
290
+ return type in `retrieve`."""
291
+ if is_numpy_array (arr ):
292
+ return arr
293
+ elif is_torch_array (arr ):
294
+ return arr .cpu ().detach ().numpy ()
295
+ else :
296
+ raise NotImplementedError (
297
+ "The pandas return type is currently only supported "
298
+ "with numpy and torch arrays." )
299
+
264
300
def retrieve (self , indices , fields = None , return_type = "dict" ):
265
301
"""Collects data at the given indices.
266
302
267
303
Args:
268
304
indices (array-like): List of indices at which to collect data.
269
305
fields (str or array-like of str): List of fields to include. By
270
306
default, all fields will be included, with an additional "index"
271
- as the last field ("index" can also be placed anywhere in this
272
- list). This can also be a single str indicating a field name.
307
+ as the last field. The "index" field can also be added anywhere
308
+ in this list of fields. This argument can also be a single str
309
+ indicating a field name.
273
310
return_type (str): Type of data to return. See the ``data`` returned
274
311
below. Ignored if ``fields`` is a str.
275
312
@@ -346,6 +383,10 @@ def retrieve(self, indices, fields=None, return_type="dict"):
346
383
Like the other return types, the columns can be adjusted with
347
384
the ``fields`` parameter.
348
385
386
+ .. note:: This return type will require copying all fields in
387
+ the ArrayStore into NumPy arrays, if they are not already
388
+ NumPy arrays.
389
+
349
390
All data returned by this method will be a copy, i.e., the data will
350
391
not update as the store changes.
351
392
@@ -354,8 +395,12 @@ def retrieve(self, indices, fields=None, return_type="dict"):
354
395
ValueError: Invalid return_type provided.
355
396
"""
356
397
single_field = isinstance (fields , str )
357
- indices = np .asarray (indices , dtype = np .int32 )
358
- occupied = self ._props ["occupied" ][indices ] # Induces copy.
398
+ indices = self ._xp .asarray (indices ,
399
+ dtype = self ._xp .int32 ,
400
+ device = self ._device )
401
+
402
+ # Induces copy (in numpy, at least).
403
+ occupied = self ._props ["occupied" ][indices ]
359
404
360
405
if single_field :
361
406
data = None
@@ -374,10 +419,10 @@ def retrieve(self, indices, fields=None, return_type="dict"):
374
419
for name in fields :
375
420
# Collect array data.
376
421
#
377
- # Note that fancy indexing with indices already creates a copy, so
378
- # only `indices` needs to be copied explicitly.
422
+ # Note that fancy indexing with indices already creates a copy (in
423
+ # numpy, at least), so only `indices` needs to be copied explicitly.
379
424
if name == "index" :
380
- arr = np . copy (indices )
425
+ arr = self . _xp . asarray (indices , copy = True )
381
426
elif name in self ._fields :
382
427
arr = self ._fields [name ][indices ] # Induces copy.
383
428
else :
@@ -391,6 +436,8 @@ def retrieve(self, indices, fields=None, return_type="dict"):
391
436
elif return_type == "tuple" :
392
437
data .append (arr )
393
438
elif return_type == "pandas" :
439
+ arr = self ._convert_to_numpy (arr )
440
+
394
441
if len (arr .shape ) == 1 : # Scalar entries.
395
442
data [name ] = arr
396
443
elif len (arr .shape ) == 2 : # 1D array entries.
@@ -405,6 +452,8 @@ def retrieve(self, indices, fields=None, return_type="dict"):
405
452
if return_type == "tuple" :
406
453
data = tuple (data )
407
454
elif return_type == "pandas" :
455
+ occupied = self ._convert_to_numpy (occupied )
456
+
408
457
# Data above are already copied, so no need to copy again.
409
458
data = ArchiveDataFrame (data , copy = False )
410
459
@@ -471,8 +520,16 @@ def add(self, indices, data):
471
520
"This can also occur if the archive and result_archive have "
472
521
"different extra_fields." )
473
522
523
+ # Determine the unique indices. These operations are preferred over
524
+ # `xp.unique_values(indices)` because they operate in linear time, while
525
+ # unique_values usually sorts the input.
526
+ indices_occupied = self ._xp .zeros (self .capacity ,
527
+ dtype = bool ,
528
+ device = self ._device )
529
+ indices_occupied [indices ] = True
530
+ unique_indices = self ._xp .nonzero (indices_occupied )[0 ]
531
+
474
532
# Update occupancy data.
475
- unique_indices = np .where (aggregate (indices , 1 , func = "len" ) != 0 )[0 ]
476
533
cur_occupied = self ._props ["occupied" ][unique_indices ]
477
534
new_indices = unique_indices [~ cur_occupied ]
478
535
n_occupied = self ._props ["n_occupied" ]
@@ -483,16 +540,18 @@ def add(self, indices, data):
483
540
484
541
# Insert into the ArrayStore. Note that we do not assume indices are
485
542
# unique. Hence, when updating occupancy data above, we computed the
486
- # unique indices. In contrast, here we let NumPy 's default behavior
543
+ # unique indices. In contrast, here we let the array 's default behavior
487
544
# handle duplicate indices.
488
545
for name , arr in self ._fields .items ():
489
- arr [indices ] = data [name ]
546
+ arr [indices ] = self ._xp .asarray (data [name ],
547
+ dtype = arr .dtype ,
548
+ device = self ._device )
490
549
491
550
def clear (self ):
492
551
"""Removes all entries from the store."""
493
552
self ._props ["updates" ][Update .CLEAR ] += 1
494
553
self ._props ["n_occupied" ] = 0 # Effectively clears occupied_list too.
495
- self ._props ["occupied" ]. fill ( False )
554
+ self ._props ["occupied" ][:] = False
496
555
497
556
def resize (self , capacity ):
498
557
"""Resizes the store to the given capacity.
@@ -512,14 +571,20 @@ def resize(self, capacity):
512
571
self ._props ["capacity" ] = capacity
513
572
514
573
cur_occupied = self ._props ["occupied" ]
515
- self ._props ["occupied" ] = np .zeros (capacity , dtype = bool )
574
+ self ._props ["occupied" ] = self ._xp .zeros (capacity ,
575
+ dtype = bool ,
576
+ device = self ._device )
516
577
self ._props ["occupied" ][:cur_capacity ] = cur_occupied
517
578
518
579
cur_occupied_list = self ._props ["occupied_list" ]
519
- self ._props ["occupied_list" ] = np .empty (capacity , dtype = np .int32 )
580
+ self ._props ["occupied_list" ] = self ._xp .empty (capacity ,
581
+ dtype = self ._xp .int32 ,
582
+ device = self ._device )
520
583
self ._props ["occupied_list" ][:cur_capacity ] = cur_occupied_list
521
584
522
585
for name , cur_arr in self ._fields .items ():
523
586
new_shape = (capacity ,) + cur_arr .shape [1 :]
524
- self ._fields [name ] = np .empty (new_shape , cur_arr .dtype )
587
+ self ._fields [name ] = self ._xp .empty (new_shape ,
588
+ dtype = cur_arr .dtype ,
589
+ device = self ._device )
525
590
self ._fields [name ][:cur_capacity ] = cur_arr
0 commit comments