@@ -141,12 +141,15 @@ def test_cross_validate_base(self) -> None:
141141 )
142142
143143 # Test LOO - use naive CV path by mocking efficient LOO
144- with mock .patch (
145- "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
146- side_effect = ValueError ("Force fallback to naive CV" ),
147- ), mock .patch .object (
148- self .adapter , "cross_validate" , wraps = self .adapter .cross_validate
149- ) as mock_cv :
144+ with (
145+ mock .patch (
146+ "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
147+ side_effect = ValueError ("Force fallback to naive CV" ),
148+ ),
149+ mock .patch .object (
150+ self .adapter , "cross_validate" , wraps = self .adapter .cross_validate
151+ ) as mock_cv ,
152+ ):
150153 result = cross_validate (adapter = self .adapter , folds = - 1 )
151154 self .assertEqual (len (result ), 4 )
152155 z = mock_cv .mock_calls
@@ -163,19 +166,23 @@ def test_cross_validate_base(self) -> None:
163166 np .array_equal (sorted (all_test ), np .array ([2.0 , 2.0 , 3.0 , 4.0 ]))
164167 )
165168 # Test LOO in transformed space - use naive path by mocking efficient LOO
166- with mock .patch (
167- "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
168- side_effect = ValueError ("Force fallback to naive CV" ),
169- ), mock .patch .object (
170- self .adapter ,
171- "_transform_inputs_for_cv" ,
172- wraps = self .adapter ._transform_inputs_for_cv ,
173- ) as mock_transform_cv , mock .patch .object (
174- self .adapter ,
175- "_cross_validate" ,
176- side_effect = lambda ** kwargs : [self .observation_data ]
177- * len (kwargs ["cv_test_points" ]),
178- ) as mock_cv :
169+ with (
170+ mock .patch (
171+ "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
172+ side_effect = ValueError ("Force fallback to naive CV" ),
173+ ),
174+ mock .patch .object (
175+ self .adapter ,
176+ "_transform_inputs_for_cv" ,
177+ wraps = self .adapter ._transform_inputs_for_cv ,
178+ ) as mock_transform_cv ,
179+ mock .patch .object (
180+ self .adapter ,
181+ "_cross_validate" ,
182+ side_effect = lambda ** kwargs : [self .observation_data ]
183+ * len (kwargs ["cv_test_points" ]),
184+ ) as mock_cv ,
185+ ):
179186 result = cross_validate (adapter = self .adapter , folds = - 1 , untransform = False )
180187 result_predicted_obs_data = [cv_result .predicted for cv_result in result ]
181188 self .assertEqual (result_predicted_obs_data , [self .observation_data ] * 4 )
@@ -246,12 +253,15 @@ def test_selector(obs: Observation) -> bool:
246253
247254 # test observation noise - use naive path by disabling efficient LOO
248255 for untransform in (True , False ):
249- with mock .patch (
250- "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
251- side_effect = ValueError ("Force fallback to naive CV" ),
252- ), mock .patch .object (
253- self .adapter , "_cross_validate" , wraps = self .adapter ._cross_validate
254- ) as mock_cv :
256+ with (
257+ mock .patch (
258+ "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
259+ side_effect = ValueError ("Force fallback to naive CV" ),
260+ ),
261+ mock .patch .object (
262+ self .adapter , "_cross_validate" , wraps = self .adapter ._cross_validate
263+ ) as mock_cv ,
264+ ):
255265 result = cross_validate (
256266 adapter = self .adapter ,
257267 folds = - 1 ,
@@ -500,9 +510,12 @@ def test_has_good_opt_config_model_fit(self) -> None:
500510 def test_efficient_loo_cv_is_attempted (self ) -> None :
501511 """Test that efficient LOO CV is attempted only when all conditions are met."""
502512 # Setup adapter with a BoTorchGenerator
503- with mock .patch (
504- "botorch.cross_validation.efficient_loo_cv"
505- ) as mock_efficient_loo , mock .patch ("botorch.cross_validation.ensemble_loo_cv" ):
513+ with (
514+ mock .patch (
515+ "botorch.cross_validation.efficient_loo_cv"
516+ ) as mock_efficient_loo ,
517+ mock .patch ("botorch.cross_validation.ensemble_loo_cv" ),
518+ ):
506519 # Create mock LOO results
507520 # Create a mock posterior
508521 mock_mean = torch .tensor ([[1.0 ], [2.0 ], [3.0 ], [4.0 ]])
@@ -570,11 +583,13 @@ def _fold_gen(td: ExperimentData) -> Iterable[CVData]:
570583
571584 # For adapter with aux experiments, directly verify the condition check
572585 # rather than running through the full cross_validate path
573- with self .subTest (condition = "has auxiliary experiments" ), mock .patch (
574- "ax.adapter.cross_validation._efficient_loo_cross_validate"
575- ) as mock_efficient , mock .patch (
576- "ax.adapter.cross_validation._fold_cross_validate"
577- ) as mock_fold :
586+ with (
587+ self .subTest (condition = "has auxiliary experiments" ),
588+ mock .patch (
589+ "ax.adapter.cross_validation._efficient_loo_cross_validate"
590+ ) as mock_efficient ,
591+ mock .patch ("ax.adapter.cross_validation._fold_cross_validate" ) as mock_fold ,
592+ ):
578593 mock_fold .return_value = []
579594 cross_validate (adapter = adapter_with_aux )
580595 self .assertFalse (
@@ -584,9 +599,12 @@ def _fold_gen(td: ExperimentData) -> Iterable[CVData]:
584599
585600 for kwargs , adapter_override , desc in conditions_preventing_efficient_loo :
586601 adapter = adapter_override or self .adapter
587- with self .subTest (condition = desc ), mock .patch (
588- "ax.adapter.cross_validation._efficient_loo_cross_validate"
589- ) as mock_efficient :
602+ with (
603+ self .subTest (condition = desc ),
604+ mock .patch (
605+ "ax.adapter.cross_validation._efficient_loo_cross_validate"
606+ ) as mock_efficient ,
607+ ):
590608 # pyre-ignore[6]: kwargs is properly typed for cross_validate
591609 cross_validate (adapter = adapter , ** kwargs )
592610 self .assertFalse (
@@ -596,13 +614,15 @@ def _fold_gen(td: ExperimentData) -> Iterable[CVData]:
596614
597615 # Test logger when efficient LOO fails even though all conditions were met
598616 with self .subTest (condition = "efficient LOO fails with exception" ):
599- with mock .patch (
600- "ax.adapter.cross_validation._efficient_loo_cross_validate"
601- ) as mock_efficient , mock .patch (
602- "ax.adapter.cross_validation._fold_cross_validate"
603- ) as mock_fold , mock .patch (
604- "ax.adapter.cross_validation.logger"
605- ) as mock_logger :
617+ with (
618+ mock .patch (
619+ "ax.adapter.cross_validation._efficient_loo_cross_validate"
620+ ) as mock_efficient ,
621+ mock .patch (
622+ "ax.adapter.cross_validation._fold_cross_validate"
623+ ) as mock_fold ,
624+ mock .patch ("ax.adapter.cross_validation.logger" ) as mock_logger ,
625+ ):
606626 # Force efficient LOO to fail
607627 mock_efficient .side_effect = ValueError ("Test failure reason" )
608628 mock_fold .return_value = []
@@ -701,13 +721,16 @@ def _test_efficient_loo_cv_matches_naive(
701721 )
702722
703723 # Run naive CV (by forcing fallback)
704- with mock .patch (
705- "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
706- side_effect = ValueError ("Force fallback to naive CV" ),
707- ), mock .patch (
708- "ax.adapter.cross_validation._fold_cross_validate" ,
709- wraps = _fold_cross_validate ,
710- ) as mock_naive_cv :
724+ with (
725+ mock .patch (
726+ "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
727+ side_effect = ValueError ("Force fallback to naive CV" ),
728+ ),
729+ mock .patch (
730+ "ax.adapter.cross_validation._fold_cross_validate" ,
731+ wraps = _fold_cross_validate ,
732+ ) as mock_naive_cv ,
733+ ):
711734 result_naive = cross_validate (
712735 adapter = adapter ,
713736 folds = - 1 ,
@@ -719,12 +742,15 @@ def _test_efficient_loo_cv_matches_naive(
719742 self .assertTrue (mock_naive_cv .called , "Naive CV not called" )
720743
721744 # Run efficient CV
722- with mock .patch (
723- "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
724- wraps = _efficient_loo_cross_validate ,
725- ) as mock_efficient , mock .patch (
726- "ax.adapter.cross_validation._fold_cross_validate" ,
727- ) as mock_naive :
745+ with (
746+ mock .patch (
747+ "ax.adapter.cross_validation._efficient_loo_cross_validate" ,
748+ wraps = _efficient_loo_cross_validate ,
749+ ) as mock_efficient ,
750+ mock .patch (
751+ "ax.adapter.cross_validation._fold_cross_validate" ,
752+ ) as mock_naive ,
753+ ):
728754 result_efficient = cross_validate (
729755 adapter = adapter ,
730756 folds = - 1 ,
@@ -765,14 +791,12 @@ def sort_key(cv_result: CVResult) -> tuple[float, ...]:
765791 if untransform :
766792 self .assertTrue (
767793 np .all (obs_means > 5.0 ),
768- f"untransform=True: expected original space, "
769- f"got { obs_means } " ,
794+ f"untransform=True: expected original space, got { obs_means } " ,
770795 )
771796 else :
772797 self .assertTrue (
773798 np .all (np .abs (obs_means ) < 3.0 ),
774- f"untransform=False: expected standardized, "
775- f"got { obs_means } " ,
799+ f"untransform=False: expected standardized, got { obs_means } " ,
776800 )
777801
778802 # Compare predictions
0 commit comments