@@ -390,7 +390,6 @@ def test_construct_inputs_qEI(self):
         self.assertTrue(torch.equal(kwargs["objective"].weights, objective.weights))
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertIsNone(kwargs["sampler"])
-        self.assertIsNone(kwargs["constraints"])
         self.assertIsInstance(kwargs["eta"], float)
         self.assertTrue(kwargs["eta"] < 1)
         multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
@@ -406,6 +405,20 @@ def test_construct_inputs_qEI(self):
             best_f=best_f_expected,
         )
         self.assertEqual(kwargs["best_f"], best_f_expected)
+        # test passing constraints
+        outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
+        kwargs = c(
+            model=mock_model,
+            training_data=self.blockX_multiY,
+            objective=objective,
+            X_pending=X_pending,
+            best_f=best_f_expected,
+            constraints=constraints,
+        )
+        self.assertIs(kwargs["constraints"], constraints)

         # testing qLogEI input constructor
         log_constructor = get_acqf_input_constructor(qLogExpectedImprovement)
@@ -415,6 +428,7 @@ def test_construct_inputs_qEI(self):
             objective=objective,
             X_pending=X_pending,
             best_f=best_f_expected,
+            constraints=constraints,
         )
         # includes strict superset of kwargs tested above
         self.assertTrue(kwargs.items() <= log_kwargs.items())
@@ -423,6 +437,7 @@ def test_construct_inputs_qEI(self):
         self.assertEqual(log_kwargs["tau_max"], TAU_MAX)
         self.assertTrue("tau_relu" in log_kwargs)
         self.assertEqual(log_kwargs["tau_relu"], TAU_RELU)
+        self.assertIs(log_kwargs["constraints"], constraints)

     def test_construct_inputs_qNEI(self):
         c = get_acqf_input_constructor(qNoisyExpectedImprovement)
@@ -441,29 +456,36 @@ def test_construct_inputs_qNEI(self):
         with self.assertRaisesRegex(ValueError, "Field `X` must be shared"):
             c(model=mock_model, training_data=self.multiX_multiY)
         X_baseline = torch.rand(2, 2)
+        outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
         kwargs = c(
             model=mock_model,
             training_data=self.blockX_blockY,
             X_baseline=X_baseline,
             prune_baseline=False,
+            constraints=constraints,
         )
         self.assertEqual(kwargs["model"], mock_model)
         self.assertIsNone(kwargs["objective"])
         self.assertIsNone(kwargs["X_pending"])
         self.assertIsNone(kwargs["sampler"])
         self.assertFalse(kwargs["prune_baseline"])
         self.assertTrue(torch.equal(kwargs["X_baseline"], X_baseline))
-        self.assertIsNone(kwargs["constraints"])
         self.assertIsInstance(kwargs["eta"], float)
         self.assertTrue(kwargs["eta"] < 1)
+        self.assertIs(kwargs["constraints"], constraints)

         # testing qLogNEI input constructor
         log_constructor = get_acqf_input_constructor(qLogNoisyExpectedImprovement)
+
         log_kwargs = log_constructor(
             model=mock_model,
             training_data=self.blockX_blockY,
             X_baseline=X_baseline,
             prune_baseline=False,
+            constraints=constraints,
         )
         # includes strict superset of kwargs tested above
         self.assertTrue(kwargs.items() <= log_kwargs.items())
@@ -472,6 +494,7 @@ def test_construct_inputs_qNEI(self):
         self.assertEqual(log_kwargs["tau_max"], TAU_MAX)
         self.assertTrue("tau_relu" in log_kwargs)
         self.assertEqual(log_kwargs["tau_relu"], TAU_RELU)
+        self.assertIs(log_kwargs["constraints"], constraints)

     def test_construct_inputs_qPI(self):
         c = get_acqf_input_constructor(qProbabilityOfImprovement)
@@ -499,23 +522,28 @@ def test_construct_inputs_qPI(self):
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertIsNone(kwargs["sampler"])
         self.assertEqual(kwargs["tau"], 1e-2)
-        self.assertIsNone(kwargs["constraints"])
         self.assertIsInstance(kwargs["eta"], float)
         self.assertTrue(kwargs["eta"] < 1)
         multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
         best_f_expected = objective(multi_Y).max()
         self.assertEqual(kwargs["best_f"], best_f_expected)
         # Check explicitly specifying `best_f`.
         best_f_expected = best_f_expected - 1  # Random value.
+        outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
         kwargs = c(
             model=mock_model,
             training_data=self.blockX_multiY,
             objective=objective,
             X_pending=X_pending,
             tau=1e-2,
             best_f=best_f_expected,
+            constraints=constraints,
         )
         self.assertEqual(kwargs["best_f"], best_f_expected)
+        self.assertIs(kwargs["constraints"], constraints)

     def test_construct_inputs_qUCB(self):
         c = get_acqf_input_constructor(qUpperConfidenceBound)
@@ -564,7 +592,7 @@ def test_construct_inputs_EHVI(self):
                 model=mock_model,
                 training_data=self.blockX_blockY,
                 objective_thresholds=objective_thresholds,
-                outcome_constraints=mock.Mock(),
+                constraints=mock.Mock(),
             )

         # test with Y_pmean supplied explicitly
@@ -702,13 +730,16 @@ def test_construct_inputs_qEHVI(self):
         weights = torch.rand(2)
         obj = WeightedMCMultiOutputObjective(weights=weights)
         outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
         X_pending = torch.rand(1, 2)
         kwargs = c(
             model=mm,
             training_data=self.blockX_blockY,
             objective_thresholds=objective_thresholds,
             objective=obj,
-            outcome_constraints=outcome_constraints,
+            constraints=constraints,
             X_pending=X_pending,
             alpha=0.05,
             eta=1e-2,
@@ -723,11 +754,7 @@ def test_construct_inputs_qEHVI(self):
         Y_expected = mean[:1] * weights
         self.assertTrue(torch.equal(partitioning._neg_Y, -Y_expected))
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
-        cons_tfs = kwargs["constraints"]
-        self.assertEqual(len(cons_tfs), 1)
-        cons_eval = cons_tfs[0](mean)
-        cons_eval_expected = torch.tensor([-0.25, 0.5])
-        self.assertTrue(torch.equal(cons_eval, cons_eval_expected))
+        self.assertIs(kwargs["constraints"], constraints)
         self.assertEqual(kwargs["eta"], 1e-2)

         # Test check for block designs
@@ -737,7 +764,7 @@ def test_construct_inputs_qEHVI(self):
                 training_data=self.multiX_multiY,
                 objective_thresholds=objective_thresholds,
                 objective=obj,
-                outcome_constraints=outcome_constraints,
+                constraints=constraints,
                 X_pending=X_pending,
                 alpha=0.05,
                 eta=1e-2,
@@ -798,6 +825,9 @@ def test_construct_inputs_qNEHVI(self):
         X_baseline = torch.rand(2, 2)
         sampler = IIDNormalSampler(sample_shape=torch.Size([4]))
         outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
         X_pending = torch.rand(1, 2)
         kwargs = c(
             model=mock_model,
@@ -806,7 +836,7 @@ def test_construct_inputs_qNEHVI(self):
             objective=objective,
             X_baseline=X_baseline,
             sampler=sampler,
-            outcome_constraints=outcome_constraints,
+            constraints=constraints,
             X_pending=X_pending,
             eta=1e-2,
             prune_baseline=True,
@@ -823,11 +853,7 @@ def test_construct_inputs_qNEHVI(self):
         self.assertIsInstance(sampler_, IIDNormalSampler)
         self.assertEqual(sampler_.sample_shape, torch.Size([4]))
         self.assertEqual(kwargs["objective"], objective)
-        cons_tfs_expected = get_outcome_constraint_transforms(outcome_constraints)
-        cons_tfs = kwargs["constraints"]
-        self.assertEqual(len(cons_tfs), 1)
-        test_Y = torch.rand(1, 2)
-        self.assertTrue(torch.equal(cons_tfs[0](test_Y), cons_tfs_expected[0](test_Y)))
+        self.assertIs(kwargs["constraints"], constraints)
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertEqual(kwargs["eta"], 1e-2)
         self.assertTrue(kwargs["prune_baseline"])
@@ -844,7 +870,7 @@ def test_construct_inputs_qNEHVI(self):
             training_data=self.blockX_blockY,
             objective_thresholds=objective_thresholds,
             objective=MultiOutputExpectation(n_w=3),
-            outcome_constraints=outcome_constraints,
+            constraints=constraints,
         )
         for use_preprocessing in (True, False):
             obj = MultiOutputExpectation(
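For reference, the `constraints` values these tests now pass are plain callables built from the matrix form of the outcome constraints. Below is a minimal sketch, separate from the patch, assuming `get_outcome_constraint_transforms` is imported from `botorch.utils.constraints` and using illustrative sample values for `Y`:

import torch

from botorch.utils.constraints import get_outcome_constraint_transforms

# Matrix form A @ y <= b: one constraint on a two-outcome model, y_1 <= 0.5,
# matching the tensors used throughout the tests above.
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))

# Returns one callable per constraint row; each maps outcome samples Y
# (shape ... x m) to Y @ A_i - b_i, so non-positive values mean feasible.
constraints = get_outcome_constraint_transforms(outcome_constraints)

Y = torch.tensor([[0.3, 0.25], [0.3, 1.0]])  # illustrative outcome samples
print(constraints[0](Y))  # tensor([-0.2500, 0.5000])

The expected `[-0.25, 0.5]` evaluation is the same one the removed qEHVI assertions checked; the tests now simply verify that the constructors pass the callables through unchanged.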