Skip to content

Commit b6b9869

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Introduce check_sobol_node test helper
Summary: Simplifies common checks for num initialization trials etc that we do for node based GS Differential Revision: D89932177 Privacy Context Container: L1307644
1 parent ca36693 commit b6b9869

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

ax/utils/testing/modeling_stubs.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import unittest
910
from logging import Logger
1011
from typing import Any
1112

@@ -62,6 +63,7 @@
6263
from botorch.models.transforms.input import InputTransform, Normalize
6364
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
6465
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
66+
from pyre_extensions import assert_is_instance
6567

6668
logger: Logger = get_logger(__name__)
6769

@@ -337,6 +339,38 @@ def sobol_gpei_generation_node_gs(
337339
return sobol_mbm_GS_nodes
338340

339341

342+
def check_sobol_node(
343+
test_case: unittest.TestCase,
344+
gs: GenerationStrategy,
345+
expected_num_trials: int,
346+
expected_min_trials_observed: int | None = None,
347+
) -> None:
348+
"""Helper to check common Sobol node properties.
349+
350+
Args:
351+
test_case: The test case instance for assertions.
352+
gs: The generation strategy to check.
353+
expected_num_trials: The expected number of trials (threshold).
354+
expected_min_trials_observed: The expected min_trials_observed threshold.
355+
If None, the check is skipped.
356+
"""
357+
sobol_node = gs._nodes[0]
358+
test_case.assertEqual(
359+
sobol_node.generator_specs[0].generator_enum, Generators.SOBOL
360+
)
361+
# First MinTrials criterion has the num_trials threshold.
362+
test_case.assertEqual(
363+
assert_is_instance(sobol_node.transition_criteria[0], MinTrials).threshold,
364+
expected_num_trials,
365+
)
366+
if expected_min_trials_observed is not None:
367+
# Second MinTrials criterion has the min_trials_observed threshold.
368+
test_case.assertEqual(
369+
assert_is_instance(sobol_node.transition_criteria[1], MinTrials).threshold,
370+
expected_min_trials_observed,
371+
)
372+
373+
340374
def get_sobol_MBM_MTGP_gs() -> GenerationStrategy:
341375
return GenerationStrategy(
342376
nodes=[

0 commit comments

Comments
 (0)