Skip to content

Commit 64e6448

Browse files
authored
Use dtypes when calling asarray in archive methods (#673)
## Description <!-- Provide a brief description of the PR's purpose here. --> Added dtypes for np.asarray in methods like retrieve and index_of across all the archives. Credit to #664 for pointing this out in DNSArchive's retrieve method. ## Status - [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md) - [x] I have linted and formatted my code with `ruff` and `ty` - [x] I have tested my code by running `pytest` - [x] I have added a description of my change to the changelog in `HISTORY.md` - [x] This PR is ready to go
1 parent 6ab830f commit 64e6448

File tree

6 files changed

+19
-18
lines changed

6 files changed

+19
-18
lines changed

HISTORY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
#### Improvements
3434

35+
- Use dtypes when calling asarray in archive methods ({pr}`673`)
3536
- Replace cKDTree usage with KDTree ({pr}`669`)
3637
- Support ProximityArchive in parallel_axes_plot ({pr}`647`)
3738
- Cast dtype when validating arguments ({pr}`646`)

ribs/archives/_categorical_archive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def clear(self) -> None:
738738
## Refer to ArchiveBase for documentation of these methods. ##
739739

740740
def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
741-
measures = np.asarray(measures)
741+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
742742
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
743743

744744
occupied, data = self._store.retrieve(self.index_of(measures))
@@ -747,7 +747,7 @@ def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
747747
return occupied, data
748748

749749
def retrieve_single(self, measures: ArrayLike) -> tuple[bool, SingleData]:
750-
measures = np.asarray(measures)
750+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
751751
check_shape(measures, "measures", self.measure_dim, "measure_dim")
752752

753753
occupied, data = self.retrieve(measures[None])

ribs/archives/_cvt_archive.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def index_of(self, measures: ArrayLike) -> np.ndarray:
541541
ValueError: ``measures`` is not of shape (batch_size, :attr:`measure_dim`).
542542
ValueError: ``measures`` has non-finite values (inf or NaN).
543543
"""
544-
measures = np.asarray(measures)
544+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
545545
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
546546
check_finite(measures, "measures")
547547

@@ -587,7 +587,7 @@ def index_of_single(self, measures: ArrayLike) -> Int:
587587
ValueError: ``measures`` is not of shape (:attr:`measure_dim`,).
588588
ValueError: ``measures`` has non-finite values (inf or NaN).
589589
"""
590-
measures = np.asarray(measures)
590+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
591591
check_shape(measures, "measures", self.measure_dim, "measure_dim")
592592
check_finite(measures, "measures")
593593
return self.index_of(measures[None])[0]
@@ -980,7 +980,7 @@ def clear(self) -> None:
980980
## Refer to ArchiveBase for documentation of these methods. ##
981981

982982
def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
983-
measures = np.asarray(measures)
983+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
984984
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
985985
check_finite(measures, "measures")
986986

@@ -990,7 +990,7 @@ def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
990990
return occupied, data
991991

992992
def retrieve_single(self, measures: ArrayLike) -> tuple[bool, SingleData]:
993-
measures = np.asarray(measures)
993+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
994994
check_shape(measures, "measures", self.measure_dim, "measure_dim")
995995
check_finite(measures, "measures")
996996

ribs/archives/_grid_archive.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def index_of(self, measures: ArrayLike) -> np.ndarray:
394394
ValueError: ``measures`` is not of shape (batch_size, :attr:`measure_dim`).
395395
ValueError: ``measures`` has non-finite values (inf or NaN).
396396
"""
397-
measures = np.asarray(measures)
397+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
398398
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
399399
check_finite(measures, "measures")
400400

@@ -426,7 +426,7 @@ def index_of_single(self, measures: ArrayLike) -> Int:
426426
ValueError: ``measures`` is not of shape (:attr:`measure_dim`,).
427427
ValueError: ``measures`` has non-finite values (inf or NaN).
428428
"""
429-
measures = np.asarray(measures)
429+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
430430
check_shape(measures, "measures", self.measure_dim, "measure_dim")
431431
check_finite(measures, "measures")
432432
return self.index_of(measures[None])[0]
@@ -858,7 +858,7 @@ def clear(self) -> None:
858858
## Refer to ArchiveBase for documentation of these methods. ##
859859

860860
def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
861-
measures = np.asarray(measures)
861+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
862862
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
863863
check_finite(measures, "measures")
864864

@@ -868,7 +868,7 @@ def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
868868
return occupied, data
869869

870870
def retrieve_single(self, measures: ArrayLike) -> tuple[bool, SingleData]:
871-
measures = np.asarray(measures)
871+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
872872
check_shape(measures, "measures", self.measure_dim, "measure_dim")
873873
check_finite(measures, "measures")
874874

ribs/archives/_proximity_archive.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def index_of(self, measures: ArrayLike) -> np.ndarray:
353353
ValueError: ``measures`` is not of shape (batch_size, :attr:`measure_dim`).
354354
ValueError: ``measures`` has non-finite values (inf or NaN).
355355
"""
356-
measures = np.asarray(measures)
356+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
357357
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
358358
check_finite(measures, "measures")
359359

@@ -383,7 +383,7 @@ def index_of_single(self, measures: ArrayLike) -> Int:
383383
ValueError: ``measures`` is not of shape (:attr:`measure_dim`,).
384384
ValueError: ``measures`` has non-finite values (inf or NaN).
385385
"""
386-
measures = np.asarray(measures)
386+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
387387
check_shape(measures, "measures", self.measure_dim, "measure_dim")
388388
check_finite(measures, "measures")
389389
return self.index_of(measures[None])[0]
@@ -788,7 +788,7 @@ def clear(self) -> None:
788788
## Refer to ArchiveBase for documentation of these methods. ##
789789

790790
def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
791-
measures = np.asarray(measures)
791+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
792792
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
793793
check_finite(measures, "measures")
794794

@@ -798,7 +798,7 @@ def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
798798
return occupied, data
799799

800800
def retrieve_single(self, measures: ArrayLike) -> tuple[bool, SingleData]:
801-
measures = np.asarray(measures)
801+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
802802
check_shape(measures, "measures", self.measure_dim, "measure_dim")
803803
check_finite(measures, "measures")
804804

ribs/archives/_sliding_boundaries_archive.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def index_of(self, measures: ArrayLike) -> np.ndarray:
457457
Raises:
458458
ValueError: ``measures`` is not of shape (batch_size, :attr:`measure_dim`).
459459
"""
460-
measures = np.asarray(measures)
460+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
461461
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
462462
check_finite(measures, "measures")
463463

@@ -498,7 +498,7 @@ def index_of_single(self, measures: ArrayLike) -> Int:
498498
ValueError: ``measures`` is not of shape (:attr:`measure_dim`,).
499499
ValueError: ``measures`` has non-finite values (inf or NaN).
500500
"""
501-
measures = np.asarray(measures)
501+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
502502
check_shape(measures, "measures", self.measure_dim, "measure_dim")
503503
check_finite(measures, "measures")
504504
return self.index_of(measures[None])[0]
@@ -720,7 +720,7 @@ def clear(self) -> None:
720720
## Refer to ArchiveBase for documentation of these methods. ##
721721

722722
def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
723-
measures = np.asarray(measures)
723+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
724724
check_batch_shape(measures, "measures", self.measure_dim, "measure_dim")
725725
check_finite(measures, "measures")
726726

@@ -730,7 +730,7 @@ def retrieve(self, measures: ArrayLike) -> tuple[np.ndarray, BatchData]:
730730
return occupied, data
731731

732732
def retrieve_single(self, measures: ArrayLike) -> tuple[bool, SingleData]:
733-
measures = np.asarray(measures)
733+
measures = np.asarray(measures, dtype=self.dtypes["measures"])
734734
check_shape(measures, "measures", self.measure_dim, "measure_dim")
735735
check_finite(measures, "measures")
736736

0 commit comments

Comments
 (0)