|
6 | 6 |
|
7 | 7 | # pyre-strict |
8 | 8 |
|
| 9 | +import unittest |
9 | 10 | from logging import Logger |
10 | 11 | from typing import Any |
11 | 12 |
|
|
62 | 63 | from botorch.models.transforms.input import InputTransform, Normalize |
63 | 64 | from botorch.models.transforms.outcome import OutcomeTransform, Standardize |
64 | 65 | from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood |
| 66 | +from pyre_extensions import assert_is_instance |
65 | 67 |
|
66 | 68 | logger: Logger = get_logger(__name__) |
67 | 69 |
|
@@ -337,6 +339,38 @@ def sobol_gpei_generation_node_gs( |
337 | 339 | return sobol_mbm_GS_nodes |
338 | 340 |
|
339 | 341 |
|
| 342 | +def check_sobol_node( |
| 343 | + test_case: unittest.TestCase, |
| 344 | + gs: GenerationStrategy, |
| 345 | + expected_num_trials: int, |
| 346 | + expected_min_trials_observed: int | None = None, |
| 347 | +) -> None: |
| 348 | + """Helper to check common Sobol node properties. |
| 349 | +
|
| 350 | + Args: |
| 351 | + test_case: The test case instance for assertions. |
| 352 | + gs: The generation strategy to check. |
| 353 | + expected_num_trials: The expected number of trials (threshold). |
| 354 | + expected_min_trials_observed: The expected min_trials_observed threshold. |
| 355 | + If None, the check is skipped. |
| 356 | + """ |
| 357 | + sobol_node = gs._nodes[0] |
| 358 | + test_case.assertEqual( |
| 359 | + sobol_node.generator_specs[0].generator_enum, Generators.SOBOL |
| 360 | + ) |
| 361 | + # First MinTrials criterion has the num_trials threshold. |
| 362 | + test_case.assertEqual( |
| 363 | + assert_is_instance(sobol_node.transition_criteria[0], MinTrials).threshold, |
| 364 | + expected_num_trials, |
| 365 | + ) |
| 366 | + if expected_min_trials_observed is not None: |
| 367 | + # Second MinTrials criterion has the min_trials_observed threshold. |
| 368 | + test_case.assertEqual( |
| 369 | + assert_is_instance(sobol_node.transition_criteria[1], MinTrials).threshold, |
| 370 | + expected_min_trials_observed, |
| 371 | + ) |
| 372 | + |
| 373 | + |
340 | 374 | def get_sobol_MBM_MTGP_gs() -> GenerationStrategy: |
341 | 375 | return GenerationStrategy( |
342 | 376 | nodes=[ |
|
0 commit comments