
Commit 3b674eb

Jelena Markovic-Voronov authored and facebook-github-bot committed
ST_MTGP model for MTGP MOO (#1962)
Summary: Pull Request resolved: #1962

- Changed the legacy Ax model ST_MTGP_NEHVI into the new modular botorch model ST_MTGP.
- Changed the number of objectives in MOO problems to depend only on the number of objectives, not on the number of constraints.

Reviewed By: saitcakmak

Differential Revision: D47739378

fbshipit-source-id: 9d0bab44d589c1e1d61c47f8be430c122539b884
1 parent ff3a394 commit 3b674eb
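The substantive change, visible in the diff below, is that construct_inputs_qNEHVI now derives the objective count from the non-NaN entries of objective_thresholds instead of from model.num_outputs, so constraint outcomes no longer inflate it. A minimal sketch of that counting logic, using a hypothetical thresholds tensor in which a constraint outcome carries a NaN threshold:

import torch

# Hypothetical setup: 2 objectives plus 1 constraint outcome; the constraint
# outcome has no objective threshold, so it appears as NaN.
objective_thresholds = torch.tensor([0.5, 1.0, float("nan")])

# Counting introduced in this commit: only the non-NaN entries are objectives.
num_objectives = objective_thresholds[~torch.isnan(objective_thresholds)].shape[0]
print(num_objectives)  # 2, independent of the number of constraint outcomes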

File tree

- botorch/acquisition/input_constructors.py
- test/acquisition/test_input_constructors.py

2 files changed: +6 −5 lines

botorch/acquisition/input_constructors.py

Lines changed: 3 additions & 2 deletions
@@ -986,7 +986,6 @@ def construct_inputs_qNEHVI(
         first_only=True,
         assert_shared=True,
     )
-
     # This selects the objectives (a subset of the outcomes) and set each
     # objective threhsold to have the proper optimization direction.
     if objective is None:
@@ -1014,6 +1013,8 @@ def construct_inputs_qNEHVI(
     else:
         ref_point = objective(objective_thresholds)
 
+    num_objectives = objective_thresholds[~torch.isnan(objective_thresholds)].shape[0]
+
     return {
         "model": model,
         "ref_point": ref_point,
@@ -1024,7 +1025,7 @@ def construct_inputs_qNEHVI(
         "X_pending": kwargs.get("X_pending"),
         "eta": kwargs.get("eta", 1e-3),
         "prune_baseline": kwargs.get("prune_baseline", True),
-        "alpha": kwargs.get("alpha", get_default_partitioning_alpha(model.num_outputs)),
+        "alpha": kwargs.get("alpha", get_default_partitioning_alpha(num_objectives)),
         "cache_pending": kwargs.get("cache_pending", True),
         "max_iep": kwargs.get("max_iep", 0),
         "incremental_nehvi": kwargs.get("incremental_nehvi", True),

test/acquisition/test_input_constructors.py

Lines changed: 3 additions & 3 deletions
@@ -810,7 +810,7 @@ def test_construct_inputs_qNEHVI(self):
             X_pending=X_pending,
             eta=1e-2,
             prune_baseline=True,
-            alpha=0.1,
+            alpha=0.0,
             cache_pending=False,
             max_iep=1,
             incremental_nehvi=False,
@@ -831,7 +831,7 @@ def test_construct_inputs_qNEHVI(self):
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertEqual(kwargs["eta"], 1e-2)
         self.assertTrue(kwargs["prune_baseline"])
-        self.assertEqual(kwargs["alpha"], 0.1)
+        self.assertEqual(kwargs["alpha"], 0.0)
         self.assertFalse(kwargs["cache_pending"])
         self.assertEqual(kwargs["max_iep"], 1)
         self.assertFalse(kwargs["incremental_nehvi"])
@@ -874,7 +874,7 @@ def test_construct_inputs_qNEHVI(self):
             training_data=self.blockX_blockY,
             objective_thresholds=objective_thresholds,
         )
-        self.assertEqual(kwargs["alpha"], 1e-3)
+        self.assertEqual(kwargs["alpha"], 0.0)
 
     def test_construct_inputs_kg(self):
         current_value = torch.tensor(1.23)
