Skip to content

Commit 42b5b95

Browse files
mpolson64facebook-github-bot
authored andcommitted
Remove isinstance(foo, MulitTypeExperiment) checks (#4740)
Summary: As titled. Should be a no-op since these are still initialized as MultiTypeExperiment. Next diff will remove instantiations. Differential Revision: D90133237 Privacy Context Container: L1413903
1 parent bfdd43b commit 42b5b95

File tree

6 files changed

+19
-22
lines changed

6 files changed

+19
-22
lines changed

ax/core/experiment.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,8 +1649,12 @@ def __repr__(self) -> str:
16491649

16501650
@property
16511651
def is_multi_type(self) -> bool:
1652-
"""Whether this Experiment contains more than one trial type."""
1653-
return len(self._trial_type_to_runner) > 1
1652+
"""
1653+
Whether this Experiment supports at least one non-None trial type.
1654+
"""
1655+
supported_types = {*self._trial_type_to_runner.keys()}
1656+
1657+
return len(supported_types - {None}) > 0
16541658

16551659
@property
16561660
def default_trial_type(self) -> str | None:

ax/orchestration/orchestrator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def trial_type(self) -> str | None:
362362
Trial type for the experiment this Orchestrator is running if the
363363
experiment is a MultiTypeExperiment and None otherwise.
364364
"""
365-
if isinstance(self.experiment, MultiTypeExperiment):
365+
if self.experiment.is_multi_type:
366366
return self.options.mt_experiment_trial_type
367367
return None
368368

@@ -1576,7 +1576,7 @@ def _validate_options(self, options: OrchestratorOptions) -> None:
15761576
"will be unable to fetch intermediate results with which to "
15771577
"evaluate early stopping criteria."
15781578
)
1579-
if isinstance(self.experiment, MultiTypeExperiment):
1579+
if self.experiment.is_multi_type:
15801580
if options.mt_experiment_trial_type is None:
15811581
raise UserInputError(
15821582
"Must specify `mt_experiment_trial_type` for MultiTypeExperiment."

ax/orchestration/tests/test_orchestrator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,7 +1307,7 @@ def test_orchestrator_with_metric_with_new_data_after_completion(self) -> None:
13071307
init_test_engine_and_session_factory(force_init=True)
13081308
branin_gs = self.two_sobol_steps_GS
13091309
# With runners & metrics, `Orchestrator.run_all_trials` should run.
1310-
if isinstance(self.branin_experiment, MultiTypeExperiment):
1310+
if self.branin_experiment.is_multi_type:
13111311
self.branin_experiment.update_runner(
13121312
"type1", SyntheticRunnerWithPredictableStatusPolling()
13131313
)
@@ -2209,7 +2209,7 @@ def test_generate_candidates_works_for_sobol(self) -> None:
22092209
gs = get_online_sobol_mbm_generation_strategy()
22102210

22112211
# this is a HITL experiment, so we don't want trials completing on their own.
2212-
if isinstance(self.branin_experiment, MultiTypeExperiment):
2212+
if self.branin_experiment.is_multi_type:
22132213
self.branin_experiment.update_runner("type1", InfinitePollRunner())
22142214
else:
22152215
self.branin_experiment.runner = InfinitePollRunner()
@@ -2259,7 +2259,7 @@ def test_generate_candidates_can_remove_stale_candidates_with_ttl(
22592259
gs = self.two_sobol_steps_GS
22602260

22612261
# this is a HITL experiment, so we don't want trials completing on their own.
2262-
if isinstance(self.branin_experiment, MultiTypeExperiment):
2262+
if self.branin_experiment.is_multi_type:
22632263
self.branin_experiment.update_runner("type1", InfinitePollRunner())
22642264
else:
22652265
self.branin_experiment.runner = InfinitePollRunner()
@@ -2312,7 +2312,7 @@ def test_generate_candidates_can_remove_stale_candidates(self) -> None:
23122312
gs = self.two_sobol_steps_GS
23132313

23142314
# this is a HITL experiment, so we don't want trials completing on their own.
2315-
if isinstance(self.branin_experiment, MultiTypeExperiment):
2315+
if self.branin_experiment.is_multi_type:
23162316
self.branin_experiment.update_runner("type1", InfinitePollRunner())
23172317
else:
23182318
self.branin_experiment.runner = InfinitePollRunner()
@@ -2389,7 +2389,7 @@ def test_generate_candidates_does_not_fail_stale_candidates_if_fails_to_gen(
23892389
gs = self.two_sobol_steps_GS
23902390

23912391
# this is a HITL experiment, so we don't want trials completing on their own.
2392-
if isinstance(self.branin_experiment, MultiTypeExperiment):
2392+
if self.branin_experiment.is_multi_type:
23932393
self.branin_experiment.update_runner("type1", InfinitePollRunner())
23942394
else:
23952395
self.branin_experiment.runner = InfinitePollRunner()

ax/service/ax_client.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from ax.core.evaluations_to_data import raw_evaluations_to_data
2525
from ax.core.experiment import Experiment
2626
from ax.core.generator_run import GeneratorRun
27-
from ax.core.multi_type_experiment import MultiTypeExperiment
2827
from ax.core.objective import MultiObjective, Objective
2928
from ax.core.observation import ObservationFeatures
3029
from ax.core.runner import Runner
@@ -445,14 +444,10 @@ def add_tracking_metrics(
445444
for metric_name in metric_names
446445
]
447446

448-
if isinstance(self.experiment, MultiTypeExperiment):
449-
experiment = assert_is_instance(self.experiment, MultiTypeExperiment)
450-
experiment.add_tracking_metrics(
451-
metrics=metric_objects,
452-
metrics_to_trial_types=metrics_to_trial_types,
453-
)
454-
else:
455-
self.experiment.add_tracking_metrics(metrics=metric_objects)
447+
self.experiment.add_tracking_metrics(
448+
metrics=metric_objects,
449+
metrics_to_trial_types=metrics_to_trial_types,
450+
)
456451

457452
@copy_doc(Experiment.remove_tracking_metric)
458453
def remove_tracking_metric(self, metric_name: str) -> None:

ax/service/utils/report_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from ax.core.generator_run import GeneratorRunType
3333
from ax.core.map_metric import MapMetric
3434
from ax.core.metric import Metric
35-
from ax.core.multi_type_experiment import MultiTypeExperiment
3635
from ax.core.objective import MultiObjective, ScalarizedObjective
3736
from ax.core.optimization_config import (
3837
MultiObjectiveOptimizationConfig,
@@ -787,7 +786,7 @@ def exp_to_df(
787786
)
788787

789788
# Accept Experiment and SimpleExperiment
790-
if isinstance(exp, MultiTypeExperiment):
789+
if exp.is_multi_type:
791790
raise ValueError("Cannot transform MultiTypeExperiments to DataFrames.")
792791

793792
key_components = ["trial_index", "arm_name"]

ax/storage/sqa_store/encoder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from ax.core.experiment import Experiment
2929
from ax.core.generator_run import GeneratorRun
3030
from ax.core.metric import Metric
31-
from ax.core.multi_type_experiment import MultiTypeExperiment
3231
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
3332
from ax.core.optimization_config import (
3433
MultiObjectiveOptimizationConfig,
@@ -212,7 +211,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
212211
]
213212
auxiliary_experiments_by_purpose[aux_exp_type] = aux_exp_jsons
214213
runners = []
215-
if isinstance(experiment, MultiTypeExperiment):
214+
if experiment.is_multi_type:
216215
experiment._properties[Keys.SUBCLASS] = "MultiTypeExperiment"
217216
for trial_type, runner in experiment._trial_type_to_runner.items():
218217
runner_sqa = self.runner_to_sqa(none_throws(runner), trial_type)

0 commit comments

Comments
 (0)