Skip to content

Commit 85ccd2d

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Break up qNEHVI test for readibility (#2076)
Summary: Pull Request resolved: #2076 This commit splits the `qNEHVI` test into three parts: 1) tests for the base functionality, as well as 2) with, and 3) without the cached box decomposition (CBD). Reviewed By: esantorella Differential Revision: D50808514 fbshipit-source-id: 84247c61c6a2c569cc4c1d7618ee887e68af7557
1 parent 01b2503 commit 85ccd2d

File tree

1 file changed

+110
-32
lines changed

1 file changed

+110
-32
lines changed

test/acquisition/multi_objective/test_monte_carlo.py

Lines changed: 110 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -721,33 +721,67 @@ def setUp(self):
721721
super().setUp()
722722

723723
def test_q_noisy_expected_hypervolume_improvement(self):
724-
for dtype, m in product(
725-
(torch.float, torch.double),
726-
(1, 2, 3),
727-
):
728-
with self.subTest(dtype=dtype, m=m):
729-
self._test_q_noisy_expected_hypervolume_improvement(
730-
qNoisyExpectedHypervolumeImprovement, dtype, m
731-
)
724+
for dtype in (torch.float, torch.double):
725+
self._test_q_noisy_expected_hypervolume_improvement_m1(
726+
qNoisyExpectedHypervolumeImprovement, dtype
727+
)
728+
for m in (2, 3):
729+
with self.subTest(dtype=dtype, m=m):
730+
self._test_q_noisy_expected_hypervolume_improvement(
731+
qNoisyExpectedHypervolumeImprovement, dtype, m
732+
)
732733

733734
def test_q_log_noisy_expected_hypervolume_improvement(self):
734-
for dtype, m in product(
735-
(torch.float, torch.double),
736-
(1, 2, 3),
735+
for dtype in (torch.float, torch.double):
736+
self._test_q_noisy_expected_hypervolume_improvement_m1(
737+
qLogNoisyExpectedHypervolumeImprovement, dtype
738+
)
739+
for m in (2, 3):
740+
with self.subTest(dtype=dtype, m=m):
741+
self._test_q_noisy_expected_hypervolume_improvement(
742+
qLogNoisyExpectedHypervolumeImprovement, dtype, m
743+
)
744+
745+
def _test_q_noisy_expected_hypervolume_improvement_m1(
746+
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype
747+
):
748+
# special case test for m = 1.
749+
(
750+
ref_point,
751+
X,
752+
X_baseline,
753+
mm,
754+
sampler,
755+
samples,
756+
baseline_samples,
757+
tkwargs,
758+
) = self._setup_qnehvi_test(dtype=dtype, m=1)
759+
# test error is raised if m == 1
760+
with self.assertRaisesRegex(
761+
ValueError,
762+
"NoisyExpectedHypervolumeMixin supports m>=2 outcomes ",
737763
):
738-
with self.subTest(dtype=dtype, m=m):
739-
self._test_q_noisy_expected_hypervolume_improvement(
740-
qLogNoisyExpectedHypervolumeImprovement, dtype, m
741-
)
764+
acqf_class(
765+
model=mm,
766+
ref_point=ref_point,
767+
X_baseline=X_baseline,
768+
sampler=sampler,
769+
cache_root=False,
770+
)
742771

743772
def _test_q_noisy_expected_hypervolume_improvement(
744773
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype, m: int
745-
):
774+
) -> None:
775+
self._test_qnehvi_base(acqf_class, dtype, m)
776+
# test with and without cached box decomposition (CBD)
777+
self._test_qnehvi_with_CBD(acqf_class, dtype, m)
778+
self._test_qnehvi_without_CBD(acqf_class, dtype, m)
779+
780+
def _setup_qnehvi_test(self, dtype: torch.dtype, m: int) -> None:
746781
tkwargs = {"device": self.device}
747782
tkwargs["dtype"] = dtype
748783
ref_point = self.ref_point[:m]
749784
Y = self.Y_raw[:, :m].to(**tkwargs)
750-
pareto_Y = self.pareto_Y_raw[:, :m].to(**tkwargs)
751785
X_baseline = torch.rand(Y.shape[0], 1, **tkwargs)
752786
# the event shape is `b x q + r x m` = 1 x 1 x 2
753787
baseline_samples = Y
@@ -759,22 +793,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
759793
X = torch.zeros(1, 1, **tkwargs)
760794
# basic test
761795
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
796+
return ref_point, X, X_baseline, mm, sampler, samples, baseline_samples, tkwargs
762797

763-
# test error is raised if m == 1
764-
if m == 1:
765-
with self.assertRaisesRegex(
766-
ValueError,
767-
"NoisyExpectedHypervolumeMixin supports m>=2 outcomes ",
768-
):
769-
acqf = acqf_class(
770-
model=mm,
771-
ref_point=ref_point,
772-
X_baseline=X_baseline,
773-
sampler=sampler,
774-
cache_root=False,
775-
)
776-
return
777-
798+
def _test_qnehvi_base(
799+
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype, m: int
800+
) -> None:
801+
(
802+
ref_point,
803+
X,
804+
X_baseline,
805+
mm,
806+
sampler,
807+
samples,
808+
baseline_samples,
809+
tkwargs,
810+
) = self._setup_qnehvi_test(dtype=dtype, m=m)
778811
acqf = acqf_class(
779812
model=mm,
780813
ref_point=ref_point,
@@ -934,6 +967,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
934967
self.assertEqual(list(b.shape), [1, 1, m])
935968
self.assertEqual(list(b.shape), [1, 1, m])
936969

970+
def _test_qnehvi_with_CBD(
971+
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype, m: int
972+
) -> None:
973+
(
974+
ref_point,
975+
X,
976+
X_baseline,
977+
mm,
978+
sampler,
979+
samples,
980+
baseline_samples,
981+
tkwargs,
982+
) = self._setup_qnehvi_test(dtype=dtype, m=m)
983+
pareto_Y = self.pareto_Y_raw[:, :m].to(**tkwargs)
984+
937985
# test no baseline points
938986
ref_point2 = [15.0, 14.0, 16.0][:m]
939987
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
@@ -1146,6 +1194,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
11461194
self.assertTrue(torch.equal(acqf_pareto_Y[:-2], expected_pareto_Y))
11471195
self.assertTrue(torch.equal(acqf_pareto_Y[-2:], expected_new_Y2))
11481196

1197+
def _test_qnehvi_without_CBD(
1198+
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype, m: int
1199+
) -> None:
1200+
tkwargs = {"device": self.device}
1201+
tkwargs["dtype"] = dtype
1202+
ref_point = self.ref_point[:m]
1203+
Y = self.Y_raw[:, :m].to(**tkwargs)
1204+
pareto_Y = self.pareto_Y_raw[:, :m].to(**tkwargs)
1205+
X_baseline = torch.rand(Y.shape[0], 1, **tkwargs)
1206+
# the event shape is `b x q + r x m` = 1 x 1 x 2
1207+
baseline_samples = Y
1208+
mm = MockModel(MockPosterior(samples=baseline_samples))
1209+
1210+
X_pending = torch.rand(1, 1, dtype=dtype, device=self.device)
1211+
11491212
# test qNEHVI without CBD
11501213
mm._posterior._samples = baseline_samples
11511214
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
@@ -1158,6 +1221,7 @@ def _test_q_noisy_expected_hypervolume_improvement(
11581221
cache_pending=False,
11591222
cache_root=False,
11601223
)
1224+
new_Y = torch.tensor([[0.5, 3.0, 0.5][:m]], dtype=dtype, device=self.device)
11611225
mm._posterior._samples = torch.cat(
11621226
[
11631227
baseline_samples,
@@ -1168,15 +1232,25 @@ def _test_q_noisy_expected_hypervolume_improvement(
11681232
acqf.set_X_pending(X_pending10)
11691233
self.assertTrue(torch.equal(acqf.X_pending, X_pending10))
11701234
acqf_pareto_Y = acqf.partitioning.pareto_Y[0]
1235+
expected_pareto_Y = pareto_Y if m == 2 else pareto_Y.cpu()
11711236
self.assertTrue(torch.equal(acqf_pareto_Y, expected_pareto_Y))
11721237
acqf.set_X_pending(X_pending)
1238+
# test incremental nehvi in forward
1239+
new_Y2 = torch.cat(
1240+
[
1241+
new_Y,
1242+
torch.tensor([[0.25, 9.5, 1.5][:m]], dtype=dtype, device=self.device),
1243+
],
1244+
dim=0,
1245+
)
11731246
mm._posterior._samples = torch.cat(
11741247
[
11751248
baseline_samples,
11761249
new_Y2,
11771250
]
11781251
).unsqueeze(0)
11791252
with torch.no_grad():
1253+
X_test = torch.rand(1, 1, dtype=dtype, device=self.device)
11801254
val = evaluate(acqf, X_test)
11811255
bd = DominatedPartitioning(
11821256
ref_point=torch.tensor(ref_point).to(**tkwargs), Y=pareto_Y
@@ -1212,6 +1286,10 @@ def _test_q_noisy_expected_hypervolume_improvement(
12121286
# test X_pending is not None on __init__
12131287
mm._posterior._samples = torch.zeros(1, 5, m, **tkwargs)
12141288
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
1289+
# add another point
1290+
X_pending2 = torch.cat(
1291+
[X_pending, torch.rand(1, 1, dtype=dtype, device=self.device)], dim=0
1292+
)
12151293
acqf = acqf_class(
12161294
model=mm,
12171295
ref_point=ref_point,

0 commit comments

Comments
 (0)