diff --git a/HISTORY.md b/HISTORY.md index b004cbfd4..302dee66e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -6,6 +6,7 @@ #### API +- Add `replace` parameter to sample_elites in archives ({pr}`682`) - Require numpy to be at least 2.0.0 ({pr}`681`) - Allow specifying `centroids` with filenames in CVTArchive ({pr}`679`) - Add `kdtree_query_kwargs` parameter to CVTArchive ({pr}`677`) diff --git a/ribs/archives/_archive_base.py b/ribs/archives/_archive_base.py index 0d5fff265..740a57628 100644 --- a/ribs/archives/_archive_base.py +++ b/ribs/archives/_archive_base.py @@ -364,12 +364,11 @@ def data( """ raise NotImplementedError("`data` has not been implemented in this archive") - def sample_elites(self, n: Int) -> BatchData: + def sample_elites(self, n: Int, replace: bool = True) -> BatchData: """Randomly samples elites from the archive. - Currently, this sampling is done uniformly at random. Furthermore, each sample - is done independently, so elites may be repeated in the sample. Additional - sampling methods may be supported in the future. + Currently, this sampling is done uniformly at random, either with or without + replacement. Additional sampling methods may be supported in the future. Example: :: @@ -382,12 +381,16 @@ def sample_elites(self, n: Int) -> BatchData: Args: n: Number of elites to sample. + replace: Whether to replace the elites when sampling. If True, the elites + will be replaced and thus will be sampled independently. Returns: A batch of elites randomly selected from the archive. Raises: IndexError: The archive is empty. + ValueError: ``n`` was greater than the number of elites in the archive when + ``replace=False``. """ raise NotImplementedError( "`sample_elites` has not been implemented in this archive" diff --git a/ribs/archives/_categorical_archive.py b/ribs/archives/_categorical_archive.py index 6bec4b44f..a6ff11556 100644 --- a/ribs/archives/_categorical_archive.py +++ b/ribs/archives/_categorical_archive.py @@ -789,11 +789,16 @@ def data( ) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame: return self._store.data(fields, return_type) - def sample_elites(self, n: Int) -> BatchData: + def sample_elites(self, n: Int, replace: bool = True) -> BatchData: if self.empty: raise IndexError("No elements in archive.") + if not replace and n > len(self._store): + raise ValueError( + "Cannot take a larger sample than the number of elites " + "in the archive when 'replace=False'" + ) - random_indices = self._rng.integers(len(self._store), size=n) + random_indices = self._rng.choice(len(self._store), size=n, replace=replace) selected_indices = self._store.occupied_list[random_indices] _, elites = self._store.retrieve(selected_indices) return elites diff --git a/ribs/archives/_cvt_archive.py b/ribs/archives/_cvt_archive.py index 46b5579bd..378347261 100644 --- a/ribs/archives/_cvt_archive.py +++ b/ribs/archives/_cvt_archive.py @@ -1104,11 +1104,16 @@ def data( ) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame: return self._store.data(fields, return_type) - def sample_elites(self, n: Int) -> BatchData: + def sample_elites(self, n: Int, replace: bool = True) -> BatchData: if self.empty: raise IndexError("No elements in archive.") + if not replace and n > len(self._store): + raise ValueError( + "Cannot take a larger sample than the number of elites " + "in the archive when 'replace=False'" + ) - random_indices = self._rng.integers(len(self._store), size=n) + random_indices = self._rng.choice(len(self._store), size=n, replace=replace) selected_indices = self._store.occupied_list[random_indices] _, elites = self._store.retrieve(selected_indices) return elites diff --git a/ribs/archives/_grid_archive.py b/ribs/archives/_grid_archive.py index 34665c941..5162dde88 100644 --- a/ribs/archives/_grid_archive.py +++ b/ribs/archives/_grid_archive.py @@ -911,11 +911,16 @@ def data( ) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame: return self._store.data(fields, return_type) - def sample_elites(self, n: Int) -> BatchData: + def sample_elites(self, n: Int, replace: bool = True) -> BatchData: if self.empty: raise IndexError("No elements in archive.") + if not replace and n > len(self._store): + raise ValueError( + "Cannot take a larger sample than the number of elites " + "in the archive when 'replace=False'" + ) - random_indices = self._rng.integers(len(self._store), size=n) + random_indices = self._rng.choice(len(self._store), size=n, replace=replace) selected_indices = self._store.occupied_list[random_indices] _, elites = self._store.retrieve(selected_indices) return elites diff --git a/ribs/archives/_proximity_archive.py b/ribs/archives/_proximity_archive.py index 7908449b5..a57d5c909 100644 --- a/ribs/archives/_proximity_archive.py +++ b/ribs/archives/_proximity_archive.py @@ -849,11 +849,16 @@ def data( ) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame: return self._store.data(fields, return_type) - def sample_elites(self, n: Int) -> BatchData: + def sample_elites(self, n: Int, replace: bool = True) -> BatchData: if self.empty: raise IndexError("No elements in archive.") + if not replace and n > len(self._store): + raise ValueError( + "Cannot take a larger sample than the number of elites " + "in the archive when 'replace=False'" + ) - random_indices = self._rng.integers(len(self._store), size=n) + random_indices = self._rng.choice(len(self._store), size=n, replace=replace) selected_indices = self._store.occupied_list[random_indices] _, elites = self._store.retrieve(selected_indices) return elites diff --git a/ribs/archives/_sliding_boundaries_archive.py b/ribs/archives/_sliding_boundaries_archive.py index 7e49d3258..1dab13af2 100644 --- a/ribs/archives/_sliding_boundaries_archive.py +++ b/ribs/archives/_sliding_boundaries_archive.py @@ -773,11 +773,16 @@ def data( ) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame: return self._store.data(fields, return_type) - def sample_elites(self, n: Int) -> BatchData: + def sample_elites(self, n: Int, replace: bool = True) -> BatchData: if self.empty: raise IndexError("No elements in archive.") + if not replace and n > len(self._store): + raise ValueError( + "Cannot take a larger sample than the number of elites " + "in the archive when 'replace=False'" + ) - random_indices = self._rng.integers(len(self._store), size=n) + random_indices = self._rng.choice(len(self._store), size=n, replace=replace) selected_indices = self._store.occupied_list[random_indices] _, elites = self._store.retrieve(selected_indices) return elites diff --git a/tests/archives/archive_base_test.py b/tests/archives/archive_base_test.py index 916aeed74..ef7773eae 100644 --- a/tests/archives/archive_base_test.py +++ b/tests/archives/archive_base_test.py @@ -593,6 +593,36 @@ def test_sample_elites_fails_when_empty(data): data.archive.sample_elites(1) +@pytest.mark.parametrize("setting_for_n", ["enough_n", "too_many_n"]) +def test_sample_elites_with_replacement(data, setting_for_n): + if isinstance(data.archive, CategoricalArchive): + data.archive.add( + solution=np.zeros((3, 3)), + objective=[1, 2, 3], + measures=[["A", "One"], ["A", "Two"], ["A", "Three"]], + ) + else: + data.archive.add( + solution=np.zeros((3, 3)), + objective=[1, 2, 3], + measures=[[-1, -1], [-1, 1], [1, 1]], + ) + + if setting_for_n == "enough_n": + # Sampling exactly 3 with replace=False should cause the 3 elites to be sampled. + elites = data.archive.sample_elites(3, replace=False) + assert np.allclose(np.sort(elites["objective"]), [1, 2, 3]) + elif setting_for_n == "too_many_n": + # Sampling more than the number of elites with replace=False throws an error. + with pytest.raises( + ValueError, + match=r"Cannot take a larger sample than the number of elites in the archive .*", + ): + elites = data.archive.sample_elites(4, replace=False) + else: + raise ValueError + + @pytest.mark.parametrize("name", ARCHIVE_NAMES) @pytest.mark.parametrize("with_elite", [True, False], ids=["nonempty", "empty"]) @pytest.mark.parametrize("dtype", [np.float64, np.float32], ids=["float64", "float32"])