@@ -721,33 +721,67 @@ def setUp(self):
721721 super ().setUp ()
722722
723723 def test_q_noisy_expected_hypervolume_improvement (self ):
724- for dtype , m in product (
725- (torch .float , torch .double ),
726- (1 , 2 , 3 ),
727- ):
728- with self .subTest (dtype = dtype , m = m ):
729- self ._test_q_noisy_expected_hypervolume_improvement (
730- qNoisyExpectedHypervolumeImprovement , dtype , m
731- )
724+ for dtype in (torch .float , torch .double ):
725+ self ._test_q_noisy_expected_hypervolume_improvement_m1 (
726+ qNoisyExpectedHypervolumeImprovement , dtype
727+ )
728+ for m in (2 , 3 ):
729+ with self .subTest (dtype = dtype , m = m ):
730+ self ._test_q_noisy_expected_hypervolume_improvement (
731+ qNoisyExpectedHypervolumeImprovement , dtype , m
732+ )
732733
733734 def test_q_log_noisy_expected_hypervolume_improvement (self ):
734- for dtype , m in product (
735- (torch .float , torch .double ),
736- (1 , 2 , 3 ),
735+ for dtype in (torch .float , torch .double ):
736+ self ._test_q_noisy_expected_hypervolume_improvement_m1 (
737+ qLogNoisyExpectedHypervolumeImprovement , dtype
738+ )
739+ for m in (2 , 3 ):
740+ with self .subTest (dtype = dtype , m = m ):
741+ self ._test_q_noisy_expected_hypervolume_improvement (
742+ qLogNoisyExpectedHypervolumeImprovement , dtype , m
743+ )
744+
745+ def _test_q_noisy_expected_hypervolume_improvement_m1 (
746+ self , acqf_class : Type [AcquisitionFunction ], dtype : torch .dtype
747+ ):
748+ # special case test for m = 1.
749+ (
750+ ref_point ,
751+ X ,
752+ X_baseline ,
753+ mm ,
754+ sampler ,
755+ samples ,
756+ baseline_samples ,
757+ tkwargs ,
758+ ) = self ._setup_qnehvi_test (dtype = dtype , m = 1 )
759+ # test error is raised if m == 1
760+ with self .assertRaisesRegex (
761+ ValueError ,
762+ "NoisyExpectedHypervolumeMixin supports m>=2 outcomes " ,
737763 ):
738- with self .subTest (dtype = dtype , m = m ):
739- self ._test_q_noisy_expected_hypervolume_improvement (
740- qLogNoisyExpectedHypervolumeImprovement , dtype , m
741- )
764+ acqf_class (
765+ model = mm ,
766+ ref_point = ref_point ,
767+ X_baseline = X_baseline ,
768+ sampler = sampler ,
769+ cache_root = False ,
770+ )
742771
743772 def _test_q_noisy_expected_hypervolume_improvement (
744773 self , acqf_class : Type [AcquisitionFunction ], dtype : torch .dtype , m : int
745- ):
774+ ) -> None :
775+ self ._test_qnehvi_base (acqf_class , dtype , m )
776+ # test with and without cached box decomposition (CBD)
777+ self ._test_qnehvi_with_CBD (acqf_class , dtype , m )
778+ self ._test_qnehvi_without_CBD (acqf_class , dtype , m )
779+
780+ def _setup_qnehvi_test (self , dtype : torch .dtype , m : int ) -> None :
746781 tkwargs = {"device" : self .device }
747782 tkwargs ["dtype" ] = dtype
748783 ref_point = self .ref_point [:m ]
749784 Y = self .Y_raw [:, :m ].to (** tkwargs )
750- pareto_Y = self .pareto_Y_raw [:, :m ].to (** tkwargs )
751785 X_baseline = torch .rand (Y .shape [0 ], 1 , ** tkwargs )
752786 # the event shape is `b x q + r x m` = 1 x 1 x 2
753787 baseline_samples = Y
@@ -759,22 +793,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
759793 X = torch .zeros (1 , 1 , ** tkwargs )
760794 # basic test
761795 sampler = IIDNormalSampler (sample_shape = torch .Size ([1 ]))
796+ return ref_point , X , X_baseline , mm , sampler , samples , baseline_samples , tkwargs
762797
763- # test error is raised if m == 1
764- if m == 1 :
765- with self .assertRaisesRegex (
766- ValueError ,
767- "NoisyExpectedHypervolumeMixin supports m>=2 outcomes " ,
768- ):
769- acqf = acqf_class (
770- model = mm ,
771- ref_point = ref_point ,
772- X_baseline = X_baseline ,
773- sampler = sampler ,
774- cache_root = False ,
775- )
776- return
777-
798+ def _test_qnehvi_base (
799+ self , acqf_class : Type [AcquisitionFunction ], dtype : torch .dtype , m : int
800+ ) -> None :
801+ (
802+ ref_point ,
803+ X ,
804+ X_baseline ,
805+ mm ,
806+ sampler ,
807+ samples ,
808+ baseline_samples ,
809+ tkwargs ,
810+ ) = self ._setup_qnehvi_test (dtype = dtype , m = m )
778811 acqf = acqf_class (
779812 model = mm ,
780813 ref_point = ref_point ,
@@ -934,6 +967,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
934967 self .assertEqual (list (b .shape ), [1 , 1 , m ])
935968 self .assertEqual (list (b .shape ), [1 , 1 , m ])
936969
970+ def _test_qnehvi_with_CBD (
971+ self , acqf_class : Type [AcquisitionFunction ], dtype : torch .dtype , m : int
972+ ) -> None :
973+ (
974+ ref_point ,
975+ X ,
976+ X_baseline ,
977+ mm ,
978+ sampler ,
979+ samples ,
980+ baseline_samples ,
981+ tkwargs ,
982+ ) = self ._setup_qnehvi_test (dtype = dtype , m = m )
983+ pareto_Y = self .pareto_Y_raw [:, :m ].to (** tkwargs )
984+
937985 # test no baseline points
938986 ref_point2 = [15.0 , 14.0 , 16.0 ][:m ]
939987 sampler = IIDNormalSampler (sample_shape = torch .Size ([1 ]))
@@ -1146,6 +1194,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
11461194 self .assertTrue (torch .equal (acqf_pareto_Y [:- 2 ], expected_pareto_Y ))
11471195 self .assertTrue (torch .equal (acqf_pareto_Y [- 2 :], expected_new_Y2 ))
11481196
1197+ def _test_qnehvi_without_CBD (
1198+ self , acqf_class : Type [AcquisitionFunction ], dtype : torch .dtype , m : int
1199+ ) -> None :
1200+ tkwargs = {"device" : self .device }
1201+ tkwargs ["dtype" ] = dtype
1202+ ref_point = self .ref_point [:m ]
1203+ Y = self .Y_raw [:, :m ].to (** tkwargs )
1204+ pareto_Y = self .pareto_Y_raw [:, :m ].to (** tkwargs )
1205+ X_baseline = torch .rand (Y .shape [0 ], 1 , ** tkwargs )
1206+ # the event shape is `b x q + r x m` = 1 x 1 x 2
1207+ baseline_samples = Y
1208+ mm = MockModel (MockPosterior (samples = baseline_samples ))
1209+
1210+ X_pending = torch .rand (1 , 1 , dtype = dtype , device = self .device )
1211+
11491212 # test qNEHVI without CBD
11501213 mm ._posterior ._samples = baseline_samples
11511214 sampler = IIDNormalSampler (sample_shape = torch .Size ([1 ]))
@@ -1158,6 +1221,7 @@ def _test_q_noisy_expected_hypervolume_improvement(
11581221 cache_pending = False ,
11591222 cache_root = False ,
11601223 )
1224+ new_Y = torch .tensor ([[0.5 , 3.0 , 0.5 ][:m ]], dtype = dtype , device = self .device )
11611225 mm ._posterior ._samples = torch .cat (
11621226 [
11631227 baseline_samples ,
@@ -1168,15 +1232,25 @@ def _test_q_noisy_expected_hypervolume_improvement(
11681232 acqf .set_X_pending (X_pending10 )
11691233 self .assertTrue (torch .equal (acqf .X_pending , X_pending10 ))
11701234 acqf_pareto_Y = acqf .partitioning .pareto_Y [0 ]
1235+ expected_pareto_Y = pareto_Y if m == 2 else pareto_Y .cpu ()
11711236 self .assertTrue (torch .equal (acqf_pareto_Y , expected_pareto_Y ))
11721237 acqf .set_X_pending (X_pending )
1238+ # test incremental nehvi in forward
1239+ new_Y2 = torch .cat (
1240+ [
1241+ new_Y ,
1242+ torch .tensor ([[0.25 , 9.5 , 1.5 ][:m ]], dtype = dtype , device = self .device ),
1243+ ],
1244+ dim = 0 ,
1245+ )
11731246 mm ._posterior ._samples = torch .cat (
11741247 [
11751248 baseline_samples ,
11761249 new_Y2 ,
11771250 ]
11781251 ).unsqueeze (0 )
11791252 with torch .no_grad ():
1253+ X_test = torch .rand (1 , 1 , dtype = dtype , device = self .device )
11801254 val = evaluate (acqf , X_test )
11811255 bd = DominatedPartitioning (
11821256 ref_point = torch .tensor (ref_point ).to (** tkwargs ), Y = pareto_Y
@@ -1212,6 +1286,10 @@ def _test_q_noisy_expected_hypervolume_improvement(
12121286 # test X_pending is not None on __init__
12131287 mm ._posterior ._samples = torch .zeros (1 , 5 , m , ** tkwargs )
12141288 sampler = IIDNormalSampler (sample_shape = torch .Size ([1 ]))
1289+ # add another point
1290+ X_pending2 = torch .cat (
1291+ [X_pending , torch .rand (1 , 1 , dtype = dtype , device = self .device )], dim = 0
1292+ )
12151293 acqf = acqf_class (
12161294 model = mm ,
12171295 ref_point = ref_point ,
0 commit comments