Skip to content

Commit f25b4ce

Browse files
esantorellafacebook-github-bot
authored andcommitted
More precise type annotations for MultiObjectiveOptimizationConfig (facebook#2620)
Summary: Pull Request resolved: facebook#2620 Context: Type annotations imply that any `Objective` works with a `MultiObjectiveOptimizationConfig`, but the code makes clear that only a `MultiObjective` or `ScalarizedObjective` works, and even has tests for this. I was misled myself into trying to privde an `Objective` to a `MultiObjectiveOptimizationConfig`. This PR: * Changes annotations from `Objective` to `Union[MultiObjective, ScalarizedObjective]` * Adds a pyre-fixme: Inconsistent override. The indirect cause of why this is needed is that `Objective.clone_with_args` returns an `Objective` type even in subclasses unless the method is overriden, rather than a self type. * Added a couple pyre-fixmes in unit tests that were deliberately testing inappropriate types. Reviewed By: saitcakmak, mpolson64 Differential Revision: D60476566 fbshipit-source-id: 03afbfdca624026e5ff8d5da3d224f6b1d676032
1 parent e9bd020 commit f25b4ce

File tree

5 files changed

+28
-14
lines changed

5 files changed

+28
-14
lines changed

ax/core/objective.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def objective_weights(self) -> Iterable[Tuple[Objective, float]]:
164164
"""Get the objectives and weights."""
165165
return zip(self.objectives, self.weights)
166166

167-
def clone(self) -> Objective:
167+
def clone(self) -> MultiObjective:
168168
"""Create a copy of the objective."""
169169
return MultiObjective(objectives=[o.clone() for o in self.objectives])
170170

@@ -235,7 +235,7 @@ def metric_weights(self) -> Iterable[Tuple[Metric, float]]:
235235
"""Get the metrics and weights."""
236236
return zip(self.metrics, self.weights)
237237

238-
def clone(self) -> Objective:
238+
def clone(self) -> ScalarizedObjective:
239239
"""Create a copy of the objective."""
240240
return ScalarizedObjective(
241241
metrics=[m.clone() for m in self.metrics],

ax/core/optimization_config.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from itertools import groupby
1010
from logging import Logger
11-
from typing import Dict, List, Optional
11+
from typing import Dict, List, Optional, Union
1212

1313
from ax.core.metric import Metric
1414
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
@@ -266,7 +266,7 @@ class MultiObjectiveOptimizationConfig(OptimizationConfig):
266266

267267
def __init__(
268268
self,
269-
objective: Objective,
269+
objective: Union[MultiObjective, ScalarizedObjective],
270270
outcome_constraints: Optional[List[OutcomeConstraint]] = None,
271271
objective_thresholds: Optional[List[ObjectiveThreshold]] = None,
272272
risk_measure: Optional[RiskMeasure] = None,
@@ -293,14 +293,15 @@ def __init__(
293293
objective_thresholds=objective_thresholds,
294294
risk_measure=risk_measure,
295295
)
296-
self._objective: Objective = objective
296+
self._objective: Union[MultiObjective, ScalarizedObjective] = objective
297297
self._outcome_constraints: List[OutcomeConstraint] = constraints
298298
self._objective_thresholds: List[ObjectiveThreshold] = objective_thresholds
299299
self.risk_measure: Optional[RiskMeasure] = risk_measure
300300

301+
# pyre-fixme[14]: Inconsistent override.
301302
def clone_with_args(
302303
self,
303-
objective: Optional[Objective] = None,
304+
objective: Optional[Union[MultiObjective, ScalarizedObjective]] = None,
304305
outcome_constraints: Optional[
305306
List[OutcomeConstraint]
306307
] = _NO_OUTCOME_CONSTRAINTS,
@@ -333,12 +334,12 @@ def clone_with_args(
333334
)
334335

335336
@property
336-
def objective(self) -> Objective:
337+
def objective(self) -> Union[MultiObjective, ScalarizedObjective]:
337338
"""Get objective."""
338339
return self._objective
339340

340341
@objective.setter
341-
def objective(self, objective: Objective) -> None:
342+
def objective(self, objective: Union[MultiObjective, ScalarizedObjective]) -> None:
342343
"""Set objective if not present in outcome constraints."""
343344
self._validate_optimization_config(
344345
objective=objective,

ax/core/tests/test_optimization_config.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,12 @@ def test_Init(self) -> None:
334334
objective=self.multi_objective, outcome_constraints=self.outcome_constraints
335335
)
336336
self.assertEqual(str(config1), MOOC_STR)
337-
with self.assertRaises(TypeError):
337+
with self.assertRaisesRegex(
338+
TypeError,
339+
"`MultiObjectiveOptimizationConfig` requires an objective of type "
340+
"`MultiObjective` or `ScalarizedObjective`.",
341+
):
342+
# pyre-fixme [8]: Incompatible attribute type
338343
config1.objective = self.objective # Wrong objective type
339344
# updating constraints is fine.
340345
config1.outcome_constraints = [self.outcome_constraint]
@@ -428,7 +433,12 @@ def test_Eq(self) -> None:
428433

429434
def test_ConstraintValidation(self) -> None:
430435
# Cannot build with non-MultiObjective
431-
with self.assertRaises(TypeError):
436+
with self.assertRaisesRegex(
437+
TypeError,
438+
"`MultiObjectiveOptimizationConfig` requires an objective of type "
439+
"`MultiObjective` or `ScalarizedObjective`.",
440+
):
441+
# pyre-fixme [6]: Incompatible parameter type
432442
MultiObjectiveOptimizationConfig(objective=self.objective)
433443

434444
# Using an outcome constraint for an objective should raise

ax/storage/sqa_store/decoder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from ax.utils.common.logger import get_logger
7878
from ax.utils.common.typeutils import not_none
7979
from pandas import read_json
80+
from pyre_extensions import assert_is_instance
8081
from sqlalchemy.orm.exc import DetachedInstanceError
8182

8283
logger: Logger = get_logger(__name__)
@@ -585,7 +586,9 @@ def opt_config_and_tracking_metrics_from_sqa(
585586

586587
if objective_thresholds or type(objective) is MultiObjective:
587588
optimization_config = MultiObjectiveOptimizationConfig(
588-
objective=objective,
589+
objective=assert_is_instance(
590+
objective, Union[MultiObjective, ScalarizedObjective]
591+
),
589592
outcome_constraints=outcome_constraints,
590593
objective_thresholds=objective_thresholds,
591594
risk_measure=risk_measure,

ax/utils/testing/core_stubs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,7 +1583,7 @@ def get_map_objective(minimize: bool = False) -> Objective:
15831583
return Objective(metric=MapMetric(name="m1"), minimize=minimize)
15841584

15851585

1586-
def get_multi_objective() -> Objective:
1586+
def get_multi_objective() -> MultiObjective:
15871587
return MultiObjective(
15881588
objectives=[
15891589
Objective(metric=Metric(name="m1"), minimize=False),
@@ -1592,7 +1592,7 @@ def get_multi_objective() -> Objective:
15921592
)
15931593

15941594

1595-
def get_custom_multi_objective() -> Objective:
1595+
def get_custom_multi_objective() -> MultiObjective:
15961596
return MultiObjective(
15971597
objectives=[
15981598
Objective(
@@ -1633,7 +1633,7 @@ def get_branin_objective(name: str = "branin", minimize: bool = False) -> Object
16331633
)
16341634

16351635

1636-
def get_branin_multi_objective(num_objectives: int = 2) -> Objective:
1636+
def get_branin_multi_objective(num_objectives: int = 2) -> MultiObjective:
16371637
_validate_num_objectives(num_objectives=num_objectives)
16381638
objectives = [
16391639
Objective(metric=get_branin_metric(name="branin_a"), minimize=True),

0 commit comments

Comments
 (0)