Skip to content

Commit 44b57e8

Browse files
authored
add tests
1 parent 1169da5 commit 44b57e8

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

bayes_opt/acquisition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def __init__(
613613
random_state: int | RandomState | None = None,
614614
) -> None:
615615
if xi < 0:
616-
error_msg = "xi must be greater than equal to 0."
616+
error_msg = "xi must be greater than or equal to 0."
617617
raise ValueError(error_msg)
618618
if exploration_decay is not None and not (0 < exploration_decay <= 1):
619619
error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."
@@ -800,7 +800,7 @@ def __init__(
800800
random_state: int | RandomState | None = None,
801801
) -> None:
802802
if xi < 0:
803-
error_msg = "xi must be greater than equal to 0."
803+
error_msg = "xi must be greater than or equal to 0."
804804
raise ValueError(error_msg)
805805
if exploration_decay is not None and not (0 < exploration_decay <= 1):
806806
error_msg = "exploration_decay must be greater than 0 and less than or equal to 1."

tests/test_acquisition.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,66 @@ def test_upper_confidence_bound_invalid_kappa_error(kappa: float):
377377
acquisition.UpperConfidenceBound(kappa=kappa)
378378

379379

380+
@pytest.mark.parametrize("exploration_decay", [-0.1, 0.0, 1.1, 2.0, np.inf])
381+
def test_upper_confidence_bound_invalid_exploration_decay_error(exploration_decay: float):
382+
with pytest.raises(
383+
ValueError, match="exploration_decay must be greater than 0 and less than or equal to 1."
384+
):
385+
acquisition.UpperConfidenceBound(kappa=1.0, exploration_decay=exploration_decay)
386+
387+
388+
@pytest.mark.parametrize("exploration_decay_delay", [-1, -10, "not_an_int", 1.5])
389+
def test_upper_confidence_bound_invalid_exploration_decay_delay_error(exploration_decay_delay):
390+
with pytest.raises(
391+
ValueError, match="exploration_decay_delay must be an integer greater than or equal to 0."
392+
):
393+
acquisition.UpperConfidenceBound(kappa=1.0, exploration_decay_delay=exploration_decay_delay)
394+
395+
396+
@pytest.mark.parametrize("xi", [-0.1, -1.0, -np.inf])
397+
def test_probability_of_improvement_invalid_xi_error(xi: float):
398+
with pytest.raises(ValueError, match="xi must be greater than or equal to 0."):
399+
acquisition.ProbabilityOfImprovement(xi=xi)
400+
401+
402+
@pytest.mark.parametrize("exploration_decay", [-0.1, 0.0, 1.1, 2.0, np.inf])
403+
def test_probability_of_improvement_invalid_exploration_decay_error(exploration_decay: float):
404+
with pytest.raises(
405+
ValueError, match="exploration_decay must be greater than 0 and less than or equal to 1."
406+
):
407+
acquisition.ProbabilityOfImprovement(xi=0.01, exploration_decay=exploration_decay)
408+
409+
410+
@pytest.mark.parametrize("exploration_decay_delay", [-1, -10, "not_an_int", 1.5])
411+
def test_probability_of_improvement_invalid_exploration_decay_delay_error(exploration_decay_delay):
412+
with pytest.raises(
413+
ValueError, match="exploration_decay_delay must be an integer greater than or equal to 0."
414+
):
415+
acquisition.ProbabilityOfImprovement(xi=0.01, exploration_decay_delay=exploration_decay_delay)
416+
417+
418+
@pytest.mark.parametrize("xi", [-0.1, -1.0, -np.inf])
419+
def test_expected_improvement_invalid_xi_error(xi: float):
420+
with pytest.raises(ValueError, match="xi must be greater than or equal to 0."):
421+
acquisition.ExpectedImprovement(xi=xi)
422+
423+
424+
@pytest.mark.parametrize("exploration_decay", [-0.1, 0.0, 1.1, 2.0, np.inf])
425+
def test_expected_improvement_invalid_exploration_decay_error(exploration_decay: float):
426+
with pytest.raises(
427+
ValueError, match="exploration_decay must be greater than 0 and less than or equal to 1."
428+
):
429+
acquisition.ExpectedImprovement(xi=0.01, exploration_decay=exploration_decay)
430+
431+
432+
@pytest.mark.parametrize("exploration_decay_delay", [-1, -10, "not_an_int", 1.5])
433+
def test_expected_improvement_invalid_exploration_decay_delay_error(exploration_decay_delay):
434+
with pytest.raises(
435+
ValueError, match="exploration_decay_delay must be an integer greater than or equal to 0."
436+
):
437+
acquisition.ExpectedImprovement(xi=0.01, exploration_decay_delay=exploration_decay_delay)
438+
439+
380440
def verify_optimizers_match(optimizer1, optimizer2):
381441
"""Helper function to verify two optimizers match."""
382442
assert len(optimizer1.space) == len(optimizer2.space)

0 commit comments

Comments
 (0)