Skip to content

Commit 2618657

Browse files
authored
Add test for Scheduler with CategoricalArchive (#569)
## Description <!-- Provide a brief description of the PR's purpose here. --> ## TODO <!-- Notable points that this PR has either accomplished or will accomplish. --> - [x] Add test that calls Scheduler with CategoricalArchive - [x] Fix assert_archive_elites in GridArchive tests ## Status - [x] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md) - [x] I have formatted my code using `yapf` - [x] I have tested my code by running `pytest` - [x] I have linted my code with `pylint` - [N/A] I have added a one-line description of my change to the changelog in `HISTORY.md` - [x] This PR is ready to go
1 parent 5202e61 commit 2618657

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

tests/archives/grid_archive_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def assert_archive_elites(
6969
data["objective"][j], objective_batch[i]))
7070

7171
if measures_batch is not None:
72-
if data["solution"].dtype.kind == "f":
72+
if data["measures"].dtype.kind == "f":
7373
measures_match = np.allclose(data["measures"][j],
7474
measures_batch[i])
7575
else:

tests/schedulers/scheduler_test.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import pytest
44

5-
from ribs.archives import GridArchive, ProximityArchive
5+
from ribs.archives import CategoricalArchive, GridArchive, ProximityArchive
66
from ribs.emitters import GaussianEmitter
77
from ribs.schedulers import BanditScheduler, Scheduler
88

@@ -431,3 +431,39 @@ def test_constant_active_emitters_bandit_scheduler():
431431
scheduler.tell(objective, measures)
432432

433433
assert scheduler.active.sum() == expected_active
434+
435+
436+
def test_scheduler_with_categorical_archive(add_mode):
437+
batch_size = 4
438+
archive = CategoricalArchive(
439+
solution_dim=2,
440+
categories=[
441+
["A", "B", "C"],
442+
["One", "Two", "Three", "Four"],
443+
],
444+
dtype={
445+
"solution": np.float32,
446+
"objective": np.float32,
447+
"measures": object,
448+
},
449+
)
450+
emitters = [
451+
GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=batch_size)
452+
]
453+
scheduler = Scheduler(archive, emitters, add_mode=add_mode)
454+
455+
measures_batch = [["A", "Four"], ["B", "Three"], ["C", "One"], ["C", "Two"]]
456+
457+
_ = scheduler.ask() # Ignore the actual values of the solutions.
458+
# We pass in 4 solutions with unique measures, so all should go into
459+
# the archive.
460+
scheduler.tell(np.ones(batch_size), measures_batch)
461+
462+
print(archive.data())
463+
464+
assert_archive_elites(
465+
archive=scheduler.archive,
466+
batch_size=batch_size,
467+
objective_batch=np.ones(batch_size),
468+
measures_batch=measures_batch,
469+
)

0 commit comments

Comments
 (0)