Skip to content

Commit 2b16f4e

Browse files
bowiechen and meta-codesync[bot]
authored and committed
apply Black 25.11.0 style in fbcode (77/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476253 fbshipit-source-id: 1d5d238886f1ef4dd75646fede3890699740aa3c
1 parent e2f8188 commit 2b16f4e

File tree

135 files changed

+642
-587
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

135 files changed

+642
-587
lines changed

ax/adapter/tests/test_base_adapter.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -619,10 +619,13 @@ def test_set_status_quo_with_multiple_observations(self) -> None:
619619
# Fetch constraint metric an additional time. This will lead to two
620620
# separate observations for the status quo arm.
621621
exp.fetch_data(metrics=[exp.metrics["branin_map_constraint"]])
622-
with self.assertNoLogs(logger=logger, level="WARN"), mock.patch(
623-
"ax.adapter.base._combine_multiple_status_quo_observations",
624-
wraps=_combine_multiple_status_quo_observations,
625-
) as mock_combine:
622+
with (
623+
self.assertNoLogs(logger=logger, level="WARN"),
624+
mock.patch(
625+
"ax.adapter.base._combine_multiple_status_quo_observations",
626+
wraps=_combine_multiple_status_quo_observations,
627+
) as mock_combine,
628+
):
626629
adapter = Adapter(
627630
experiment=exp,
628631
generator=Generator(),
@@ -660,9 +663,12 @@ def test_set_status_quo_with_multiple_observations(self) -> None:
660663
)
661664

662665
# Case 2: Experiment has an optimization config with no map metrics
663-
with mock.patch(
664-
"ax.adapter.base.has_map_metrics", return_value=False
665-
) as mock_extract, self.assertLogs(logger=logger, level="WARN") as mock_logs:
666+
with (
667+
mock.patch(
668+
"ax.adapter.base.has_map_metrics", return_value=False
669+
) as mock_extract,
670+
self.assertLogs(logger=logger, level="WARN") as mock_logs,
671+
):
666672
adapter = Adapter(
667673
experiment=exp,
668674
generator=Generator(),
@@ -1109,9 +1115,10 @@ def mock_predict(
11091115
self.assertTrue(np.allclose(f["m1"], np.ones(3) * 2.0))
11101116

11111117
# Test for error if an observation is dropped.
1112-
with mock.patch.object(
1113-
adapter, "_predict", side_effect=mock_predict
1114-
), self.assertRaisesRegex(ModelError, "Predictions resulted in fewer"):
1118+
with (
1119+
mock.patch.object(adapter, "_predict", side_effect=mock_predict),
1120+
self.assertRaisesRegex(ModelError, "Predictions resulted in fewer"),
1121+
):
11151122
adapter.predict(
11161123
observation_features=[
11171124
ObservationFeatures(parameters={"x": 3.0, "y": 4.0}),

ax/adapter/tests/test_cross_validation.py

Lines changed: 84 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -141,12 +141,15 @@ def test_cross_validate_base(self) -> None:
141141
)
142142

143143
# Test LOO - use naive CV path by mocking efficient LOO
144-
with mock.patch(
145-
"ax.adapter.cross_validation._efficient_loo_cross_validate",
146-
side_effect=ValueError("Force fallback to naive CV"),
147-
), mock.patch.object(
148-
self.adapter, "cross_validate", wraps=self.adapter.cross_validate
149-
) as mock_cv:
144+
with (
145+
mock.patch(
146+
"ax.adapter.cross_validation._efficient_loo_cross_validate",
147+
side_effect=ValueError("Force fallback to naive CV"),
148+
),
149+
mock.patch.object(
150+
self.adapter, "cross_validate", wraps=self.adapter.cross_validate
151+
) as mock_cv,
152+
):
150153
result = cross_validate(adapter=self.adapter, folds=-1)
151154
self.assertEqual(len(result), 4)
152155
z = mock_cv.mock_calls
@@ -163,19 +166,23 @@ def test_cross_validate_base(self) -> None:
163166
np.array_equal(sorted(all_test), np.array([2.0, 2.0, 3.0, 4.0]))
164167
)
165168
# Test LOO in transformed space - use naive path by mocking efficient LOO
166-
with mock.patch(
167-
"ax.adapter.cross_validation._efficient_loo_cross_validate",
168-
side_effect=ValueError("Force fallback to naive CV"),
169-
), mock.patch.object(
170-
self.adapter,
171-
"_transform_inputs_for_cv",
172-
wraps=self.adapter._transform_inputs_for_cv,
173-
) as mock_transform_cv, mock.patch.object(
174-
self.adapter,
175-
"_cross_validate",
176-
side_effect=lambda **kwargs: [self.observation_data]
177-
* len(kwargs["cv_test_points"]),
178-
) as mock_cv:
169+
with (
170+
mock.patch(
171+
"ax.adapter.cross_validation._efficient_loo_cross_validate",
172+
side_effect=ValueError("Force fallback to naive CV"),
173+
),
174+
mock.patch.object(
175+
self.adapter,
176+
"_transform_inputs_for_cv",
177+
wraps=self.adapter._transform_inputs_for_cv,
178+
) as mock_transform_cv,
179+
mock.patch.object(
180+
self.adapter,
181+
"_cross_validate",
182+
side_effect=lambda **kwargs: [self.observation_data]
183+
* len(kwargs["cv_test_points"]),
184+
) as mock_cv,
185+
):
179186
result = cross_validate(adapter=self.adapter, folds=-1, untransform=False)
180187
result_predicted_obs_data = [cv_result.predicted for cv_result in result]
181188
self.assertEqual(result_predicted_obs_data, [self.observation_data] * 4)
@@ -246,12 +253,15 @@ def test_selector(obs: Observation) -> bool:
246253

247254
# test observation noise - use naive path by disabling efficient LOO
248255
for untransform in (True, False):
249-
with mock.patch(
250-
"ax.adapter.cross_validation._efficient_loo_cross_validate",
251-
side_effect=ValueError("Force fallback to naive CV"),
252-
), mock.patch.object(
253-
self.adapter, "_cross_validate", wraps=self.adapter._cross_validate
254-
) as mock_cv:
256+
with (
257+
mock.patch(
258+
"ax.adapter.cross_validation._efficient_loo_cross_validate",
259+
side_effect=ValueError("Force fallback to naive CV"),
260+
),
261+
mock.patch.object(
262+
self.adapter, "_cross_validate", wraps=self.adapter._cross_validate
263+
) as mock_cv,
264+
):
255265
result = cross_validate(
256266
adapter=self.adapter,
257267
folds=-1,
@@ -500,9 +510,12 @@ def test_has_good_opt_config_model_fit(self) -> None:
500510
def test_efficient_loo_cv_is_attempted(self) -> None:
501511
"""Test that efficient LOO CV is attempted only when all conditions are met."""
502512
# Setup adapter with a BoTorchGenerator
503-
with mock.patch(
504-
"botorch.cross_validation.efficient_loo_cv"
505-
) as mock_efficient_loo, mock.patch("botorch.cross_validation.ensemble_loo_cv"):
513+
with (
514+
mock.patch(
515+
"botorch.cross_validation.efficient_loo_cv"
516+
) as mock_efficient_loo,
517+
mock.patch("botorch.cross_validation.ensemble_loo_cv"),
518+
):
506519
# Create mock LOO results
507520
# Create a mock posterior
508521
mock_mean = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
@@ -570,11 +583,13 @@ def _fold_gen(td: ExperimentData) -> Iterable[CVData]:
570583

571584
# For adapter with aux experiments, directly verify the condition check
572585
# rather than running through the full cross_validate path
573-
with self.subTest(condition="has auxiliary experiments"), mock.patch(
574-
"ax.adapter.cross_validation._efficient_loo_cross_validate"
575-
) as mock_efficient, mock.patch(
576-
"ax.adapter.cross_validation._fold_cross_validate"
577-
) as mock_fold:
586+
with (
587+
self.subTest(condition="has auxiliary experiments"),
588+
mock.patch(
589+
"ax.adapter.cross_validation._efficient_loo_cross_validate"
590+
) as mock_efficient,
591+
mock.patch("ax.adapter.cross_validation._fold_cross_validate") as mock_fold,
592+
):
578593
mock_fold.return_value = []
579594
cross_validate(adapter=adapter_with_aux)
580595
self.assertFalse(
@@ -584,9 +599,12 @@ def _fold_gen(td: ExperimentData) -> Iterable[CVData]:
584599

585600
for kwargs, adapter_override, desc in conditions_preventing_efficient_loo:
586601
adapter = adapter_override or self.adapter
587-
with self.subTest(condition=desc), mock.patch(
588-
"ax.adapter.cross_validation._efficient_loo_cross_validate"
589-
) as mock_efficient:
602+
with (
603+
self.subTest(condition=desc),
604+
mock.patch(
605+
"ax.adapter.cross_validation._efficient_loo_cross_validate"
606+
) as mock_efficient,
607+
):
590608
# pyre-ignore[6]: kwargs is properly typed for cross_validate
591609
cross_validate(adapter=adapter, **kwargs)
592610
self.assertFalse(
@@ -596,13 +614,15 @@ def _fold_gen(td: ExperimentData) -> Iterable[CVData]:
596614

597615
# Test logger when efficient LOO fails even though all conditions were met
598616
with self.subTest(condition="efficient LOO fails with exception"):
599-
with mock.patch(
600-
"ax.adapter.cross_validation._efficient_loo_cross_validate"
601-
) as mock_efficient, mock.patch(
602-
"ax.adapter.cross_validation._fold_cross_validate"
603-
) as mock_fold, mock.patch(
604-
"ax.adapter.cross_validation.logger"
605-
) as mock_logger:
617+
with (
618+
mock.patch(
619+
"ax.adapter.cross_validation._efficient_loo_cross_validate"
620+
) as mock_efficient,
621+
mock.patch(
622+
"ax.adapter.cross_validation._fold_cross_validate"
623+
) as mock_fold,
624+
mock.patch("ax.adapter.cross_validation.logger") as mock_logger,
625+
):
606626
# Force efficient LOO to fail
607627
mock_efficient.side_effect = ValueError("Test failure reason")
608628
mock_fold.return_value = []
@@ -701,13 +721,16 @@ def _test_efficient_loo_cv_matches_naive(
701721
)
702722

703723
# Run naive CV (by forcing fallback)
704-
with mock.patch(
705-
"ax.adapter.cross_validation._efficient_loo_cross_validate",
706-
side_effect=ValueError("Force fallback to naive CV"),
707-
), mock.patch(
708-
"ax.adapter.cross_validation._fold_cross_validate",
709-
wraps=_fold_cross_validate,
710-
) as mock_naive_cv:
724+
with (
725+
mock.patch(
726+
"ax.adapter.cross_validation._efficient_loo_cross_validate",
727+
side_effect=ValueError("Force fallback to naive CV"),
728+
),
729+
mock.patch(
730+
"ax.adapter.cross_validation._fold_cross_validate",
731+
wraps=_fold_cross_validate,
732+
) as mock_naive_cv,
733+
):
711734
result_naive = cross_validate(
712735
adapter=adapter,
713736
folds=-1,
@@ -719,12 +742,15 @@ def _test_efficient_loo_cv_matches_naive(
719742
self.assertTrue(mock_naive_cv.called, "Naive CV not called")
720743

721744
# Run efficient CV
722-
with mock.patch(
723-
"ax.adapter.cross_validation._efficient_loo_cross_validate",
724-
wraps=_efficient_loo_cross_validate,
725-
) as mock_efficient, mock.patch(
726-
"ax.adapter.cross_validation._fold_cross_validate",
727-
) as mock_naive:
745+
with (
746+
mock.patch(
747+
"ax.adapter.cross_validation._efficient_loo_cross_validate",
748+
wraps=_efficient_loo_cross_validate,
749+
) as mock_efficient,
750+
mock.patch(
751+
"ax.adapter.cross_validation._fold_cross_validate",
752+
) as mock_naive,
753+
):
728754
result_efficient = cross_validate(
729755
adapter=adapter,
730756
folds=-1,
@@ -765,14 +791,12 @@ def sort_key(cv_result: CVResult) -> tuple[float, ...]:
765791
if untransform:
766792
self.assertTrue(
767793
np.all(obs_means > 5.0),
768-
f"untransform=True: expected original space, "
769-
f"got {obs_means}",
794+
f"untransform=True: expected original space, got {obs_means}",
770795
)
771796
else:
772797
self.assertTrue(
773798
np.all(np.abs(obs_means) < 3.0),
774-
f"untransform=False: expected standardized, "
775-
f"got {obs_means}",
799+
f"untransform=False: expected standardized, got {obs_means}",
776800
)
777801

778802
# Compare predictions

ax/adapter/tests/test_hierarchical_search_space.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212

1313
from ax.adapter.cross_validation import cross_validate
1414
from ax.adapter.registry import Generators
15-
1615
from ax.core.experiment import Experiment
1716
from ax.core.objective import Objective
1817
from ax.core.observation import ObservationFeatures
@@ -23,7 +22,6 @@
2322
ParameterType,
2423
RangeParameter,
2524
)
26-
2725
from ax.core.search_space import SearchSpace
2826
from ax.core.trial import Trial
2927
from ax.metrics.noisy_function import GenericNoisyFunctionMetric

ax/adapter/tests/test_torch_adapter.py

Lines changed: 29 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -214,9 +214,12 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None:
214214
pending_observations = {
215215
"y2": [ObservationFeatures(parameters={"x1": 1.0, "x2": 2.0, "x3": 3.0})]
216216
}
217-
with ExitStack() as es, mock.patch.object(
218-
generator, "gen", return_value=gen_return_value
219-
) as mock_gen:
217+
with (
218+
ExitStack() as es,
219+
mock.patch.object(
220+
generator, "gen", return_value=gen_return_value
221+
) as mock_gen,
222+
):
220223
es.enter_context(
221224
mock.patch.object(
222225
generator, "best_point", return_value=best_point_return_value
@@ -327,9 +330,10 @@ def test_evaluate_acquisition_function(self) -> None:
327330
obsf = ObservationFeatures(parameters={"x1": 1.0, "x2": 2.0})
328331

329332
# Check for value error when optimization config is not set.
330-
with mock.patch.object(
331-
adapter, "_optimization_config", None
332-
), self.assertRaisesRegex(ValueError, "optimization_config"):
333+
with (
334+
mock.patch.object(adapter, "_optimization_config", None),
335+
self.assertRaisesRegex(ValueError, "optimization_config"),
336+
):
333337
adapter.evaluate_acquisition_function(observation_features=[obsf])
334338

335339
mock_acq_val = 5.0
@@ -422,11 +426,14 @@ def test_best_point(self) -> None:
422426
gen_return_value = TorchGenResults(
423427
points=torch.tensor([[1.0]]), weights=torch.tensor([1.0])
424428
)
425-
with mock.patch(
426-
f"{TorchGenerator.__module__}.TorchGenerator.best_point",
427-
return_value=torch.tensor([best_point_value]),
428-
autospec=True,
429-
), mock.patch.object(adapter, "predict", return_value=predict_return_value):
429+
with (
430+
mock.patch(
431+
f"{TorchGenerator.__module__}.TorchGenerator.best_point",
432+
return_value=torch.tensor([best_point_value]),
433+
autospec=True,
434+
),
435+
mock.patch.object(adapter, "predict", return_value=predict_return_value),
436+
):
430437
with mock.patch.object(
431438
adapter.generator, "gen", return_value=gen_return_value
432439
):
@@ -823,14 +830,17 @@ def test_gen_metadata_untransform(self) -> None:
823830
weights=torch.tensor([1.0]),
824831
gen_metadata={Keys.EXPECTED_ACQF_VAL: [1.0], **additional_metadata},
825832
)
826-
with mock.patch.object(
827-
adapter,
828-
"_untransform_objective_thresholds",
829-
wraps=adapter._untransform_objective_thresholds,
830-
) as mock_untransform, mock.patch.object(
831-
generator,
832-
"gen",
833-
return_value=gen_return_value,
833+
with (
834+
mock.patch.object(
835+
adapter,
836+
"_untransform_objective_thresholds",
837+
wraps=adapter._untransform_objective_thresholds,
838+
) as mock_untransform,
839+
mock.patch.object(
840+
generator,
841+
"gen",
842+
return_value=gen_return_value,
843+
),
834844
):
835845
adapter.gen(n=1)
836846
if additional_metadata.get("objective_thresholds", None) is None:

ax/adapter/tests/test_torch_moo_adapter.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -640,15 +640,18 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
640640
torch_device=torch.device("cuda" if cuda else "cpu"),
641641
)
642642
self.assertIn("Cast", adapter.transforms)
643-
with patch.object(
644-
adapter,
645-
"_untransform_objective_thresholds",
646-
wraps=adapter._untransform_objective_thresholds,
647-
) as mock_untransform, patch.object(
648-
adapter.transforms["Cast"],
649-
"untransform_observation_features",
650-
wraps=adapter.transforms["Cast"].untransform_observation_features,
651-
) as wrapped_cast:
643+
with (
644+
patch.object(
645+
adapter,
646+
"_untransform_objective_thresholds",
647+
wraps=adapter._untransform_objective_thresholds,
648+
) as mock_untransform,
649+
patch.object(
650+
adapter.transforms["Cast"],
651+
"untransform_observation_features",
652+
wraps=adapter.transforms["Cast"].untransform_observation_features,
653+
) as wrapped_cast,
654+
):
652655
obj_thresholds = adapter.infer_objective_thresholds(
653656
search_space=exp.search_space,
654657
optimization_config=exp.optimization_config,

ax/adapter/torch.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -780,9 +780,9 @@ def _get_fit_args(
780780
if update_outcomes_and_parameters:
781781
self.outcomes = ordered_outcomes
782782
else:
783-
assert (
784-
ordered_outcomes == self.outcomes
785-
), f"Unexpected ordering of outcomes: {ordered_outcomes} != {self.outcomes}"
783+
assert ordered_outcomes == self.outcomes, (
784+
f"Unexpected ordering of outcomes: {ordered_outcomes} != {self.outcomes}"
785+
)
786786
return datasets, candidate_metadata, search_space_digest
787787

788788
def _fit(

0 commit comments

Comments (0)