@@ -205,9 +205,12 @@ def test_TorchAdapter(self, device: torch.device | None = None) -> None:
         pending_observations = {
             "y2": [ObservationFeatures(parameters={"x1": 1.0, "x2": 2.0, "x3": 3.0})]
         }
-        with ExitStack() as es, mock.patch.object(
-            generator, "gen", return_value=gen_return_value
-        ) as mock_gen:
+        with (
+            ExitStack() as es,
+            mock.patch.object(
+                generator, "gen", return_value=gen_return_value
+            ) as mock_gen,
+        ):
             es.enter_context(
                 mock.patch.object(
                     generator, "best_point", return_value=best_point_return_value
@@ -318,9 +321,10 @@ def test_evaluate_acquisition_function(self) -> None:
         obsf = ObservationFeatures(parameters={"x1": 1.0, "x2": 2.0})
 
         # Check for value error when optimization config is not set.
-        with mock.patch.object(
-            adapter, "_optimization_config", None
-        ), self.assertRaisesRegex(ValueError, "optimization_config"):
+        with (
+            mock.patch.object(adapter, "_optimization_config", None),
+            self.assertRaisesRegex(ValueError, "optimization_config"),
+        ):
             adapter.evaluate_acquisition_function(observation_features=[obsf])
 
         mock_acq_val = 5.0
@@ -413,11 +417,14 @@ def test_best_point(self) -> None:
         gen_return_value = TorchGenResults(
             points=torch.tensor([[1.0]]), weights=torch.tensor([1.0])
         )
-        with mock.patch(
-            f"{TorchGenerator.__module__}.TorchGenerator.best_point",
-            return_value=torch.tensor([best_point_value]),
-            autospec=True,
-        ), mock.patch.object(adapter, "predict", return_value=predict_return_value):
+        with (
+            mock.patch(
+                f"{TorchGenerator.__module__}.TorchGenerator.best_point",
+                return_value=torch.tensor([best_point_value]),
+                autospec=True,
+            ),
+            mock.patch.object(adapter, "predict", return_value=predict_return_value),
+        ):
             with mock.patch.object(
                 adapter.generator, "gen", return_value=gen_return_value
             ):
@@ -814,14 +821,17 @@ def test_gen_metadata_untransform(self) -> None:
             weights=torch.tensor([1.0]),
             gen_metadata={Keys.EXPECTED_ACQF_VAL: [1.0], **additional_metadata},
         )
-        with mock.patch.object(
-            adapter,
-            "_untransform_objective_thresholds",
-            wraps=adapter._untransform_objective_thresholds,
-        ) as mock_untransform, mock.patch.object(
-            generator,
-            "gen",
-            return_value=gen_return_value,
+        with (
+            mock.patch.object(
+                adapter,
+                "_untransform_objective_thresholds",
+                wraps=adapter._untransform_objective_thresholds,
+            ) as mock_untransform,
+            mock.patch.object(
+                generator,
+                "gen",
+                return_value=gen_return_value,
+            ),
         ):
             adapter.gen(n=1)
             if additional_metadata.get("objective_thresholds", None) is None:
0 commit comments