Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#### API

- Add `replace` parameter to `sample_elites` in archives ({pr}`682`)
- Require numpy to be at least 2.0.0 ({pr}`681`)
- Allow specifying `centroids` with filenames in CVTArchive ({pr}`679`)
- Add `kdtree_query_kwargs` parameter to CVTArchive ({pr}`677`)
Expand Down
11 changes: 7 additions & 4 deletions ribs/archives/_archive_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,11 @@ def data(
"""
raise NotImplementedError("`data` has not been implemented in this archive")

def sample_elites(self, n: Int) -> BatchData:
def sample_elites(self, n: Int, replace: bool = True) -> BatchData:
"""Randomly samples elites from the archive.

Currently, this sampling is done uniformly at random. Furthermore, each sample
is done independently, so elites may be repeated in the sample. Additional
sampling methods may be supported in the future.
Currently, this sampling is done uniformly at random, either with or without
replacement. Additional sampling methods may be supported in the future.

Example:
::
Expand All @@ -382,12 +381,16 @@ def sample_elites(self, n: Int) -> BatchData:

Args:
n: Number of elites to sample.
replace: Whether to sample with replacement. If True, each draw is made
independently, so the same elite may appear more than once in the sample.
If False, each elite appears at most once in the sample.

Returns:
A batch of elites randomly selected from the archive.

Raises:
IndexError: The archive is empty.
ValueError: ``n`` was greater than the number of elites in the archive when
``replace=False``.
"""
raise NotImplementedError(
"`sample_elites` has not been implemented in this archive"
Expand Down
9 changes: 7 additions & 2 deletions ribs/archives/_categorical_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,11 +789,16 @@ def data(
) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame:
return self._store.data(fields, return_type)

def sample_elites(self, n: Int) -> BatchData:
def sample_elites(self, n: Int, replace: bool = True) -> BatchData:
if self.empty:
raise IndexError("No elements in archive.")
if not replace and n > len(self._store):
raise ValueError(
"Cannot take a larger sample than the number of elites "
"in the archive when 'replace=False'"
)

random_indices = self._rng.integers(len(self._store), size=n)
random_indices = self._rng.choice(len(self._store), size=n, replace=replace)
selected_indices = self._store.occupied_list[random_indices]
_, elites = self._store.retrieve(selected_indices)
return elites
9 changes: 7 additions & 2 deletions ribs/archives/_cvt_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,11 +1104,16 @@ def data(
) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame:
return self._store.data(fields, return_type)

def sample_elites(self, n: Int) -> BatchData:
def sample_elites(self, n: Int, replace: bool = True) -> BatchData:
if self.empty:
raise IndexError("No elements in archive.")
if not replace and n > len(self._store):
raise ValueError(
"Cannot take a larger sample than the number of elites "
"in the archive when 'replace=False'"
)

random_indices = self._rng.integers(len(self._store), size=n)
random_indices = self._rng.choice(len(self._store), size=n, replace=replace)
selected_indices = self._store.occupied_list[random_indices]
_, elites = self._store.retrieve(selected_indices)
return elites
9 changes: 7 additions & 2 deletions ribs/archives/_grid_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,11 +911,16 @@ def data(
) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame:
return self._store.data(fields, return_type)

def sample_elites(self, n: Int) -> BatchData:
def sample_elites(self, n: Int, replace: bool = True) -> BatchData:
if self.empty:
raise IndexError("No elements in archive.")
if not replace and n > len(self._store):
raise ValueError(
"Cannot take a larger sample than the number of elites "
"in the archive when 'replace=False'"
)

random_indices = self._rng.integers(len(self._store), size=n)
random_indices = self._rng.choice(len(self._store), size=n, replace=replace)
selected_indices = self._store.occupied_list[random_indices]
_, elites = self._store.retrieve(selected_indices)
return elites
Expand Down
9 changes: 7 additions & 2 deletions ribs/archives/_proximity_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,11 +849,16 @@ def data(
) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame:
return self._store.data(fields, return_type)

def sample_elites(self, n: Int) -> BatchData:
def sample_elites(self, n: Int, replace: bool = True) -> BatchData:
if self.empty:
raise IndexError("No elements in archive.")
if not replace and n > len(self._store):
raise ValueError(
"Cannot take a larger sample than the number of elites "
"in the archive when 'replace=False'"
)

random_indices = self._rng.integers(len(self._store), size=n)
random_indices = self._rng.choice(len(self._store), size=n, replace=replace)
selected_indices = self._store.occupied_list[random_indices]
_, elites = self._store.retrieve(selected_indices)
return elites
9 changes: 7 additions & 2 deletions ribs/archives/_sliding_boundaries_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,11 +773,16 @@ def data(
) -> np.ndarray | BatchData | tuple[np.ndarray] | ArchiveDataFrame:
return self._store.data(fields, return_type)

def sample_elites(self, n: Int) -> BatchData:
def sample_elites(self, n: Int, replace: bool = True) -> BatchData:
if self.empty:
raise IndexError("No elements in archive.")
if not replace and n > len(self._store):
raise ValueError(
"Cannot take a larger sample than the number of elites "
"in the archive when 'replace=False'"
)

random_indices = self._rng.integers(len(self._store), size=n)
random_indices = self._rng.choice(len(self._store), size=n, replace=replace)
selected_indices = self._store.occupied_list[random_indices]
_, elites = self._store.retrieve(selected_indices)
return elites
30 changes: 30 additions & 0 deletions tests/archives/archive_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,36 @@ def test_sample_elites_fails_when_empty(data):
data.archive.sample_elites(1)


@pytest.mark.parametrize("setting_for_n", ["enough_n", "too_many_n"])
def test_sample_elites_with_replacement(data, setting_for_n):
    """Checks the ``replace`` parameter of ``sample_elites`` with ``replace=False``.

    Three solutions with objectives 1, 2, and 3 are added to the archive. Then:

    - ``enough_n``: sampling exactly 3 elites without replacement must return
      each of the three objectives exactly once.
    - ``too_many_n``: sampling 4 elites without replacement must raise a
      ``ValueError``.

    Args:
        data: Archive fixture; only ``data.archive`` is used here.
        setting_for_n: Which of the two scenarios above to run.
    """
    # CategoricalArchive is given string-valued measures; every other archive
    # is given numeric measures.
    if isinstance(data.archive, CategoricalArchive):
        measures = [["A", "One"], ["A", "Two"], ["A", "Three"]]
    else:
        measures = [[-1, -1], [-1, 1], [1, 1]]

    data.archive.add(
        solution=np.zeros((3, 3)),
        objective=[1, 2, 3],
        measures=measures,
    )

    if setting_for_n == "enough_n":
        # Sampling exactly 3 with replace=False should cause the 3 elites to be sampled.
        elites = data.archive.sample_elites(3, replace=False)
        assert np.allclose(np.sort(elites["objective"]), [1, 2, 3])
    elif setting_for_n == "too_many_n":
        # Sampling more than the number of elites with replace=False throws an error.
        # The return value is not bound since the call is expected to raise.
        with pytest.raises(
            ValueError,
            match=r"Cannot take a larger sample than the number of elites in the archive .*",
        ):
            data.archive.sample_elites(4, replace=False)
    else:
        # Guards against a parametrize value that has no matching branch.
        raise ValueError(f"Unknown setting_for_n: {setting_for_n!r}")


@pytest.mark.parametrize("name", ARCHIVE_NAMES)
@pytest.mark.parametrize("with_elite", [True, False], ids=["nonempty", "empty"])
@pytest.mark.parametrize("dtype", [np.float64, np.float32], ids=["float64", "float32"])
Expand Down
Loading