Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions ax/utils/testing/modeling_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

import unittest
from logging import Logger
from typing import Any

Expand Down Expand Up @@ -62,6 +63,7 @@
from botorch.models.transforms.input import InputTransform, Normalize
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from pyre_extensions import assert_is_instance

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -337,6 +339,40 @@ def sobol_gpei_generation_node_gs(
return sobol_mbm_GS_nodes


def check_sobol_node(
test_case: unittest.TestCase,
gs: GenerationStrategy,
expected_num_trials: int,
expected_min_trials_observed: int | None = None,
) -> None:
"""Helper to check common Sobol node properties.

Args:
test_case: The test case instance for assertions.
gs: The generation strategy to check.
expected_num_trials: The expected number of trials that need to be generated
before the transition to the next node.
expected_min_trials_observed: The expected number of trial that needs to be
observed (i.e., completed) before the transition to the next node.
If None, the check is skipped.
"""
sobol_node = gs._nodes[0]
test_case.assertEqual(
sobol_node.generator_specs[0].generator_enum, Generators.SOBOL
)
# First MinTrials criterion has the num_trials threshold.
test_case.assertEqual(
assert_is_instance(sobol_node.transition_criteria[0], MinTrials).threshold,
expected_num_trials,
)
if expected_min_trials_observed is not None:
# Second MinTrials criterion has the min_trials_observed threshold.
test_case.assertEqual(
assert_is_instance(sobol_node.transition_criteria[1], MinTrials).threshold,
expected_min_trials_observed,
)


def get_sobol_MBM_MTGP_gs() -> GenerationStrategy:
return GenerationStrategy(
nodes=[
Expand Down