Skip to content

Commit f0711fc

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Use get_default_partitioning_alpha for NEHVI input constructor (#1481)
Summary: Pull Request resolved: #1481 This brings Ax MBM behavior on par with legacy models. Reviewed By: Balandat Differential Revision: D41087000 fbshipit-source-id: 01c0e27095c1bba44eb34e4f0e54cc9422070fc1
1 parent 7563a0b commit f0711fc

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ def construct_inputs_qNEHVI(
826826
"X_pending": kwargs.get("X_pending"),
827827
"eta": kwargs.get("eta", 1e-3),
828828
"prune_baseline": kwargs.get("prune_baseline", True),
829-
"alpha": kwargs.get("alpha", 0.0),
829+
"alpha": kwargs.get("alpha", get_default_partitioning_alpha(model.num_outputs)),
830830
"cache_pending": kwargs.get("cache_pending", True),
831831
"max_iep": kwargs.get("max_iep", 0),
832832
"incremental_nehvi": kwargs.get("incremental_nehvi", True),

test/acquisition/test_input_constructors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,7 @@ def test_construct_inputs_qNEHVI(self):
690690
c = get_acqf_input_constructor(qNoisyExpectedHypervolumeImprovement)
691691
objective_thresholds = torch.rand(2)
692692
mock_model = mock.Mock()
693+
mock_model.num_outputs = 2
693694

694695
# Test defaults
695696
kwargs = c(
@@ -793,6 +794,15 @@ def test_construct_inputs_qNEHVI(self):
793794
self.assertIs(kwargs["objective"], obj)
794795
self.assertTrue(torch.equal(kwargs["ref_point"], expected_obj_t))
795796

797+
# Test default alpha for many objectives/
798+
mock_model.num_outputs = 5
799+
kwargs = c(
800+
model=mock_model,
801+
training_data=self.blockX_blockY,
802+
objective_thresholds=objective_thresholds,
803+
)
804+
self.assertEqual(kwargs["alpha"], 1e-3)
805+
796806
def test_construct_inputs_kg(self):
797807
current_value = torch.tensor(1.23)
798808
with mock.patch(

0 commit comments

Comments
 (0)