@@ -382,6 +382,36 @@ def test_optimize_acqf_sequential_notimplemented(self):
382
382
sequential = True ,
383
383
)
384
384
385
+ def test_optimize_acqf_batch_limit (self ) -> None :
386
+ num_restarts = 3
387
+ raw_samples = 5
388
+ dim = 4
389
+ q = 4
390
+ batch_limit = 2
391
+
392
+ options = {"batch_limit" : batch_limit }
393
+ initial_conditions = [
394
+ torch .ones (shape ) for shape in [(1 , 2 , dim ), (2 , 1 , dim ), (1 , dim )]
395
+ ] + [None ]
396
+
397
+ for gen_candidates , ics in zip (
398
+ [gen_candidates_scipy , gen_candidates_torch ], initial_conditions
399
+ ):
400
+ with self .subTest (gen_candidates = gen_candidates , initial_conditions = ics ):
401
+ _ , acq_value_list = optimize_acqf (
402
+ acq_function = SinOneOverXAcqusitionFunction (),
403
+ bounds = torch .stack ([- 1 * torch .ones (dim ), torch .ones (dim )]),
404
+ q = q ,
405
+ num_restarts = num_restarts ,
406
+ raw_samples = raw_samples ,
407
+ options = options ,
408
+ return_best_only = False ,
409
+ gen_candidates = gen_candidates ,
410
+ batch_initial_conditions = ics ,
411
+ )
412
+ expected_shape = (num_restarts ,) if ics is None else (ics .shape [0 ],)
413
+ self .assertEqual (acq_value_list .shape , expected_shape )
414
+
385
415
def test_optimize_acqf_runs_given_batch_initial_conditions (self ):
386
416
num_restarts , raw_samples , dim = 1 , 2 , 3
387
417
0 commit comments