|
2 | 2 | import numpy as np
|
3 | 3 | import pytest
|
4 | 4 |
|
5 |
| -from ribs.archives import GridArchive, ProximityArchive |
| 5 | +from ribs.archives import CategoricalArchive, GridArchive, ProximityArchive |
6 | 6 | from ribs.emitters import GaussianEmitter
|
7 | 7 | from ribs.schedulers import BanditScheduler, Scheduler
|
8 | 8 |
|
@@ -431,3 +431,39 @@ def test_constant_active_emitters_bandit_scheduler():
|
431 | 431 | scheduler.tell(objective, measures)
|
432 | 432 |
|
433 | 433 | assert scheduler.active.sum() == expected_active
|
| 434 | + |
| 435 | + |
| 436 | +def test_scheduler_with_categorical_archive(add_mode): |
| 437 | + batch_size = 4 |
| 438 | + archive = CategoricalArchive( |
| 439 | + solution_dim=2, |
| 440 | + categories=[ |
| 441 | + ["A", "B", "C"], |
| 442 | + ["One", "Two", "Three", "Four"], |
| 443 | + ], |
| 444 | + dtype={ |
| 445 | + "solution": np.float32, |
| 446 | + "objective": np.float32, |
| 447 | + "measures": object, |
| 448 | + }, |
| 449 | + ) |
| 450 | + emitters = [ |
| 451 | + GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=batch_size) |
| 452 | + ] |
| 453 | + scheduler = Scheduler(archive, emitters, add_mode=add_mode) |
| 454 | + |
| 455 | + measures_batch = [["A", "Four"], ["B", "Three"], ["C", "One"], ["C", "Two"]] |
| 456 | + |
| 457 | + _ = scheduler.ask() # Ignore the actual values of the solutions. |
| 458 | + # We pass in 4 solutions with unique measures, so all should go into |
| 459 | + # the archive. |
| 460 | + scheduler.tell(np.ones(batch_size), measures_batch) |
| 461 | + |
| 462 | + print(archive.data()) |
| 463 | + |
| 464 | + assert_archive_elites( |
| 465 | + archive=scheduler.archive, |
| 466 | + batch_size=batch_size, |
| 467 | + objective_batch=np.ones(batch_size), |
| 468 | + measures_batch=measures_batch, |
| 469 | + ) |
0 commit comments