diff --git a/ax/core/experiment.py b/ax/core/experiment.py index fd0ca7bf8b5..4c7846a551c 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -9,7 +9,6 @@ from __future__ import annotations import inspect - import logging import warnings from collections import defaultdict @@ -150,10 +149,13 @@ def __init__( self._trials: dict[int, BaseTrial] = {} self._properties: dict[str, Any] = properties or {} + # Specifies which trial type each metric belongs to + self._metric_to_trial_type: dict[str, str | None] = {} + # Initialize trial type to runner mapping - self._default_trial_type = default_trial_type + self._default_trial_type: str | None = default_trial_type self._trial_type_to_runner: dict[str | None, Runner | None] = { - default_trial_type: runner + self._default_trial_type: runner } # Used to keep track of whether any trials on the experiment # specify a TTL. Since trials need to be checked for their TTL's @@ -417,13 +419,13 @@ def runner(self, runner: Runner | None) -> None: if runner is not None: self._trial_type_to_runner[self._default_trial_type] = runner else: - self._trial_type_to_runner = {None: None} + self._trial_type_to_runner = {self._default_trial_type: None} @runner.deleter def runner(self) -> None: """Delete the runner.""" self._runner = None - self._trial_type_to_runner = {None: None} + self._trial_type_to_runner = {self._default_trial_type: None} @property def parameters(self) -> dict[str, Parameter]: @@ -493,6 +495,11 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None: for metric_name in optimization_config.metrics.keys(): if metric_name in self._tracking_metrics: self.remove_tracking_metric(metric_name) + + # Optimization config metrics are required to be the default trial type + # currently. TODO: remove that restriction (T202797235) + self._metric_to_trial_type[metric_name] = self.default_trial_type + # add metrics from the previous optimization config that are not in the new # optimization config as tracking metrics prev_optimization_config = self._optimization_config @@ -554,11 +561,16 @@ def immutable_search_space_and_opt_config(self) -> bool: def tracking_metrics(self) -> list[Metric]: return list(self._tracking_metrics.values()) - def add_tracking_metric(self, metric: Metric) -> Experiment: + def add_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Experiment: """Add a new metric to the experiment. Args: metric: Metric to be added. + trial_type: The trial type for which this metric is used. """ if metric.name in self._tracking_metrics: raise ValueError( @@ -574,10 +586,19 @@ def add_tracking_metric(self, metric: Metric) -> Experiment: "before adding it to tracking metrics." ) + if trial_type is None: + trial_type = self._default_trial_type + self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = trial_type + return self - def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment: + def add_tracking_metrics( + self, + metrics: list[Metric], + metrics_to_trial_types: dict[str, str] | None = None, + ) -> Experiment: """Add a list of new metrics to the experiment. If any of the metrics are already defined on the experiment, @@ -585,23 +606,52 @@ def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment: Args: metrics: Metrics to be added. + metrics_to_trial_types: The mapping from metric names to corresponding + trial types for each metric. If provided, the metrics will be + added to their trial types. 
If not provided, then the default + trial type will be used. """ - # Before setting any metrics, we validate none are already on - # the experiment + metrics_to_trial_types = metrics_to_trial_types or {} + for metric in metrics: - self.add_tracking_metric(metric) + self.add_tracking_metric( + metric=metric, + trial_type=metrics_to_trial_types.get(metric.name), + ) return self - def update_tracking_metric(self, metric: Metric) -> Experiment: + def update_tracking_metric( + self, + metric: Metric, + trial_type: str | None = None, + ) -> Experiment: """Redefine a metric that already exists on the experiment. Args: metric: New metric definition. + trial_type: The trial type for which this metric is used. """ + if trial_type is None: + trial_type = self._default_trial_type + + oc = self.optimization_config + oc_metrics = oc.metrics if oc else [] + if metric.name in oc_metrics and trial_type != self._default_trial_type: + raise ValueError( + f"Metric `{metric.name}` must remain a " + f"`{self._default_trial_type}` metric because it is part of the " + "optimization_config." + ) + + if not self.supports_trial_type(trial_type): + raise ValueError(f"`{trial_type}` is not a supported trial type.") + if metric.name not in self._tracking_metrics: raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.") self._tracking_metrics[metric.name] = metric + self._metric_to_trial_type[metric.name] = trial_type + return self def remove_tracking_metric(self, metric_name: str) -> Experiment: @@ -614,6 +664,8 @@ def remove_tracking_metric(self, metric_name: str) -> Experiment: raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") del self._tracking_metrics[metric_name] + del self._metric_to_trial_type[metric_name] + return self @property @@ -777,6 +829,21 @@ def fetch_data( Returns: Data for the experiment. """ + if self.is_multi_type: + # TODO: make this more efficient for fetching + # data for multiple trials of the same type + # by overriding Experiment._lookup_or_fetch_trials_results + return Data.from_multiple_data( + [ + ( + trial.fetch_data(**kwargs, metrics=metrics) + if trial.status.expecting_data + else Data() + ) + for trial in self.trials.values() + ] + ) + results = self._lookup_or_fetch_trials_results( trials=list(self.trials.values()) if trial_indices is None @@ -853,8 +920,16 @@ def _fetch_trial_data( ) -> dict[str, MetricFetchResult]: trial = self.trials[trial_index] + metrics_for_trial_type = [ + metric + for metric in metrics or self.metrics.values() + if self.metric_to_trial_type[metric.name] == trial.trial_type + ] + trial_data = self._lookup_or_fetch_trials_results( - trials=[trial], metrics=metrics, **kwargs + trials=[trial], + metrics=metrics_for_trial_type, + **kwargs, ) if trial_index in trial_data: @@ -1566,19 +1641,43 @@ def __repr__(self) -> str: return self.__class__.__name__ + f"({self._name})" # --- MultiTypeExperiment convenience functions --- - # - # Certain functionalities have special behavior for multi-type experiments. - # This defines the base behavior for regular experiments that will be - # overridden in the MultiTypeExperiment class. + # A canonical use case for this is tuning a large production system + # with limited evaluation budget and a simulator which approximates + # evaluations on the main system. Trial deployment and data fetching + # is separate for the two systems, but the final data is combined and + # fed into multi-task models. 
+
+    @property
+    def is_multi_type(self) -> bool:
+        """Whether this Experiment contains more than one trial type."""
+        return len(self._trial_type_to_runner) > 1
 
     @property
     def default_trial_type(self) -> str | None:
-        """Default trial type assigned to trials in this experiment.
+        """Default trial type assigned to trials in this experiment."""
+        return self._default_trial_type
 
-        In the base experiment class this is always None. For experiments
-        with multiple trial types, use the MultiTypeExperiment class.
+    @property
+    def default_trials(self) -> set[int]:
+        """Return the indices for trials of the default type."""
+        return {
+            idx
+            for idx, trial in self.trials.items()
+            if trial.trial_type == self.default_trial_type
+        }
+
+    def add_trial_type(self, trial_type: str, runner: Runner) -> "Experiment":
+        """Add a new trial_type to be supported by this experiment.
+
+        Args:
+            trial_type: The new trial_type to be added.
+            runner: The default runner for trials of this type.
         """
-        return self._default_trial_type
+        if self.supports_trial_type(trial_type):
+            raise ValueError(f"Experiment already contains trial_type `{trial_type}`")
+
+        self._trial_type_to_runner[trial_type] = runner
+        return self
 
     def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
         """The default runner to use for a given trial type.
@@ -1591,20 +1690,55 @@ def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
             return self.runner  # return the default runner
         return runner
 
+    def update_runner(self, trial_type: str, runner: Runner) -> "Experiment":
+        """Update the default runner for an existing trial_type.
+
+        Args:
+            trial_type: An existing trial_type whose runner should be replaced.
+            runner: The new runner for trials of this type.
+        """
+        if not self.supports_trial_type(trial_type):
+            raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")
+
+        self._trial_type_to_runner[trial_type] = runner
+        self._runner = runner
+        return self
+
+    @property
+    def metric_to_trial_type(self) -> dict[str, str]:
+        """Map metric names to their trial types.
+
+        Optimization config metrics are always assigned the default trial type.
+        """
+        if self.optimization_config is not None:
+            opt_config_types = {
+                metric_name: self.default_trial_type
+                for metric_name in self.optimization_config.metrics.keys()
+            }
+        else:
+            opt_config_types = {}
+
+        return {**opt_config_types, **self._metric_to_trial_type}
+
+    def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
+        """The metrics attached to a given trial type.
+
+        Looks up the metrics mapped to this trial type in ``_metric_to_trial_type``.
+        """
+        if not self.supports_trial_type(trial_type):
+            raise ValueError(f"Trial type `{trial_type}` is not supported.")
+        return [
+            self.metrics[metric_name]
+            for metric_name, metric_trial_type in self._metric_to_trial_type.items()
+            if metric_trial_type == trial_type
+        ]
+
     def supports_trial_type(self, trial_type: str | None) -> bool:
         """Whether this experiment allows trials of the given type.
 
-        The base experiment class only supports None. For experiments
-        with multiple trial types, use the MultiTypeExperiment class.
+        Only trial types defined in the trial_type_to_runner are allowed.
         """
-        return (
-            trial_type is None
-            # We temporarily allow "short run" and "long run" trial
-            # types in single-type experiments during development of
-            # a new ``GenerationStrategy`` that needs them.
-            or trial_type == Keys.SHORT_RUN
-            or trial_type == Keys.LONG_RUN
-        )
+        return trial_type in self._trial_type_to_runner.keys()
 
     def attach_trial(
         self,
@@ -2206,3 +2340,44 @@ def add_arm_and_prevent_naming_collision(
                 stacklevel=2,
             )
         new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=True))
+
+
+def filter_trials_by_type(
+    trials: Sequence[BaseTrial], trial_type: str | None
+) -> list[BaseTrial]:
+    """Filter trials by trial type, if one is provided.
+
+    If ``trial_type`` is None, the trials are returned unfiltered.
+
+    Args:
+        trials: Trials to filter.
+        trial_type: Trial type to filter by, or None to keep all trials.
+
+    Returns:
+        Filtered trials.
+    """
+    if trial_type is not None:
+        return [t for t in trials if t.trial_type == trial_type]
+    return list(trials)
+
+
+def get_trial_indices_for_statuses(
+    experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
+) -> set[int]:
+    """Get trial indices for a set of statuses.
+
+    Args:
+        experiment: Experiment whose trials are considered.
+        statuses: Set of statuses to get trial indices for.
+        trial_type: If provided, only trials of this type are considered.
+
+    Returns:
+        Set of trial indices for the given statuses.
+    """
+    return {
+        i
+        for i, t in experiment.trials.items()
+        if (t.status in statuses)
+        and (
+            (trial_type is None)
+            or ((trial_type is not None) and (t.trial_type == trial_type))
+        )
+    }
diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py
index d40884ae93d..fd709fba2f4 100644
--- a/ax/core/multi_type_experiment.py
+++ b/ax/core/multi_type_experiment.py
@@ -6,360 +6,17 @@
 
 # pyre-strict
 
-from collections.abc import Iterable, Sequence
-from typing import Any
-
-from ax.core.arm import Arm
-from ax.core.base_trial import BaseTrial, TrialStatus
-from ax.core.data import Data
 from ax.core.experiment import Experiment
-from ax.core.metric import Metric, MetricFetchResult
-from ax.core.optimization_config import OptimizationConfig
-from ax.core.runner import Runner
-from ax.core.search_space import SearchSpace
-from ax.utils.common.docutils import copy_doc
-from pyre_extensions import none_throws
 
 
 class MultiTypeExperiment(Experiment):
-    """Class for experiment with multiple trial types.
-
-    A canonical use case for this is tuning a large production system
-    with limited evaluation budget and a simulator which approximates
-    evaluations on the main system. Trial deployment and data fetching
-    is separate for the two systems, but the final data is combined and
-    fed into multi-task models.
-
-    See the Multi-Task Modeling tutorial for more details.
-
-    Attributes:
-        name: Name of the experiment.
-        description: Description of the experiment.
     """
+    Deprecated.
 
-    def __init__(
-        self,
-        name: str,
-        search_space: SearchSpace,
-        default_trial_type: str,
-        default_runner: Runner | None,
-        optimization_config: OptimizationConfig | None = None,
-        tracking_metrics: list[Metric] | None = None,
-        status_quo: Arm | None = None,
-        description: str | None = None,
-        is_test: bool = False,
-        experiment_type: str | None = None,
-        properties: dict[str, Any] | None = None,
-        default_data_type: Any = None,
-    ) -> None:
-        """Inits Experiment.
-
-        Args:
-            name: Name of the experiment.
-            search_space: Search space of the experiment.
-            default_trial_type: Default type for trials on this experiment.
-            default_runner: Default runner for trials of the default type.
-            optimization_config: Optimization config of the experiment.
-            tracking_metrics: Additional tracking metrics not used for optimization.
-                These are associated with the default trial type.
-            runner: Default runner used for trials on this experiment.
-            status_quo: Arm representing existing "control" arm.
- description: Description of the experiment. - is_test: Convenience metadata tracker for the user to mark test experiments. - experiment_type: The class of experiments this one belongs to. - properties: Dictionary of this experiment's properties. - default_data_type: Deprecated and ignored. - """ - - # Specifies which trial type each metric belongs to - self._metric_to_trial_type: dict[str, str] = {} - - # Maps certain metric names to a canonical name. Useful for ancillary trial - # types' metrics, to specify which primary metrics they correspond to - # (e.g. 'comment_prediction' => 'comment') - self._metric_to_canonical_name: dict[str, str] = {} - - # call super.__init__() after defining fields above, because we need - # them to be populated before optimization config is set - super().__init__( - name=name, - search_space=search_space, - optimization_config=optimization_config, - status_quo=status_quo, - description=description, - is_test=is_test, - experiment_type=experiment_type, - properties=properties, - tracking_metrics=tracking_metrics, - runner=default_runner, - default_trial_type=default_trial_type, - default_data_type=default_data_type, - ) - - def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment": - """Add a new trial_type to be supported by this experiment. - - Args: - trial_type: The new trial_type to be added. - runner: The default runner for trials of this type. - """ - if self.supports_trial_type(trial_type): - raise ValueError(f"Experiment already contains trial_type `{trial_type}`") - - self._trial_type_to_runner[trial_type] = runner - return self - - # pyre-fixme [56]: Pyre was not able to infer the type of the decorator - # `Experiment.optimization_config.setter`. - @Experiment.optimization_config.setter - def optimization_config(self, optimization_config: OptimizationConfig) -> None: - # pyre-fixme [16]: `Optional` has no attribute `fset`. - Experiment.optimization_config.fset(self, optimization_config) - for metric_name in optimization_config.metrics.keys(): - # Optimization config metrics are required to be the default trial type - # currently. TODO: remove that restriction (T202797235) - self._metric_to_trial_type[metric_name] = none_throws( - self.default_trial_type - ) - - def update_runner(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment": - """Update the default runner for an existing trial_type. - - Args: - trial_type: The new trial_type to be added. - runner: The new runner for trials of this type. - """ - if not self.supports_trial_type(trial_type): - raise ValueError(f"Experiment does not contain trial_type `{trial_type}`") - - self._trial_type_to_runner[trial_type] = runner - self._runner = runner - return self - - def add_tracking_metric( - self, - metric: Metric, - trial_type: str | None = None, - canonical_name: str | None = None, - ) -> "MultiTypeExperiment": - """Add a new metric to the experiment. - - Args: - metric: The metric to add. - trial_type: The trial type for which this metric is used. - canonical_name: The default metric for which this metric is a proxy. 
- """ - if trial_type is None: - trial_type = self._default_trial_type - if not self.supports_trial_type(trial_type): - raise ValueError(f"`{trial_type}` is not a supported trial type.") - - super().add_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = none_throws(trial_type) - if canonical_name is not None: - self._metric_to_canonical_name[metric.name] = canonical_name - return self - - def add_tracking_metrics( - self, - metrics: list[Metric], - metrics_to_trial_types: dict[str, str] | None = None, - canonical_names: dict[str, str] | None = None, - ) -> Experiment: - """Add a list of new metrics to the experiment. - - If any of the metrics are already defined on the experiment, - we raise an error and don't add any of them to the experiment - - Args: - metrics: Metrics to be added. - metrics_to_trial_types: The mapping from metric names to corresponding - trial types for each metric. If provided, the metrics will be - added to their trial types. If not provided, then the default - trial type will be used. - canonical_names: A mapping of metric names to their - canonical names(The default metrics for which the metrics are - proxies.) - - Returns: - The experiment with the added metrics. - """ - metrics_to_trial_types = metrics_to_trial_types or {} - canonical_name = None - for metric in metrics: - if canonical_names is not None: - canonical_name = none_throws(canonical_names).get(metric.name, None) - - self.add_tracking_metric( - metric=metric, - trial_type=metrics_to_trial_types.get( - metric.name, self._default_trial_type - ), - canonical_name=canonical_name, - ) - return self - - # pyre-fixme[14]: `update_tracking_metric` overrides method defined in - # `Experiment` inconsistently. - def update_tracking_metric( - self, metric: Metric, trial_type: str, canonical_name: str | None = None - ) -> "MultiTypeExperiment": - """Update an existing metric on the experiment. + Functionality has been upstreamed to Experiment; please use Experiment directly + instead of MultiTypeExperiment. - Args: - metric: The metric to add. - trial_type: The trial type for which this metric is used. - canonical_name: The default metric for which this metric is a proxy. - """ - oc = self.optimization_config - oc_metrics = oc.metrics if oc else [] - if metric.name in oc_metrics and trial_type != self._default_trial_type: - raise ValueError( - f"Metric `{metric.name}` must remain a " - f"`{self._default_trial_type}` metric because it is part of the " - "optimization_config." 
- ) - elif not self.supports_trial_type(trial_type): - raise ValueError(f"`{trial_type}` is not a supported trial type.") - - super().update_tracking_metric(metric) - self._metric_to_trial_type[metric.name] = trial_type - if canonical_name is not None: - self._metric_to_canonical_name[metric.name] = canonical_name - return self - - @copy_doc(Experiment.remove_tracking_metric) - def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment": - if metric_name not in self._tracking_metrics: - raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.") - - # Required fields - del self._tracking_metrics[metric_name] - del self._metric_to_trial_type[metric_name] - - # Optional - if metric_name in self._metric_to_canonical_name: - del self._metric_to_canonical_name[metric_name] - return self - - @copy_doc(Experiment.fetch_data) - def fetch_data( - self, - trial_indices: Iterable[int] | None = None, - metrics: list[Metric] | None = None, - **kwargs: Any, - ) -> Data: - # TODO: make this more efficient for fetching - # data for multiple trials of the same type - # by overriding Experiment._lookup_or_fetch_trials_results - return Data.from_multiple_data( - [ - ( - trial.fetch_data(**kwargs, metrics=metrics) - if trial.status.expecting_data - else Data() - ) - for trial in self.trials.values() - ] - ) - - @copy_doc(Experiment._fetch_trial_data) - def _fetch_trial_data( - self, trial_index: int, metrics: list[Metric] | None = None, **kwargs: Any - ) -> dict[str, MetricFetchResult]: - trial = self.trials[trial_index] - metrics = [ - metric - for metric in (metrics or self.metrics.values()) - if self.metric_to_trial_type[metric.name] == trial.trial_type - ] - # Invoke parent's fetch method using only metrics for this trial_type - return super()._fetch_trial_data(trial.index, metrics=metrics, **kwargs) - - @property - def default_trials(self) -> set[int]: - """Return the indicies for trials of the default type.""" - return { - idx - for idx, trial in self.trials.items() - if trial.trial_type == self.default_trial_type - } - - @property - def metric_to_trial_type(self) -> dict[str, str]: - """Map metrics to trial types. - - Adds in default trial type for OC metrics to custom defined trial types.. - """ - opt_config_types = { - metric_name: self.default_trial_type - for metric_name in self.optimization_config.metrics.keys() - } - return {**opt_config_types, **self._metric_to_trial_type} - - # -- Overridden functions from Base Experiment Class -- - @property - def default_trial_type(self) -> str | None: - """Default trial type assigned to trials in this experiment.""" - return self._default_trial_type - - def metrics_for_trial_type(self, trial_type: str) -> list[Metric]: - """The default runner to use for a given trial type. - - Looks up the appropriate runner for this trial type in the trial_type_to_runner. - """ - if not self.supports_trial_type(trial_type): - raise ValueError(f"Trial type `{trial_type}` is not supported.") - return [ - self.metrics[metric_name] - for metric_name, metric_trial_type in self._metric_to_trial_type.items() - if metric_trial_type == trial_type - ] - - def supports_trial_type(self, trial_type: str | None) -> bool: - """Whether this experiment allows trials of the given type. - - Only trial types defined in the trial_type_to_runner are allowed. 
- """ - return trial_type in self._trial_type_to_runner.keys() - - -def filter_trials_by_type( - trials: Sequence[BaseTrial], trial_type: str | None -) -> list[BaseTrial]: - """Filter trials by trial type if provided. - - This filters trials by trial type if the experiment is a - MultiTypeExperiment. - - Args: - trials: Trials to filter. - - Returns: - Filtered trials. + Class remains instantiable for backwards compatibility. """ - if trial_type is not None: - return [t for t in trials if t.trial_type == trial_type] - return list(trials) - - -def get_trial_indices_for_statuses( - experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None -) -> set[int]: - """Get trial indices for a set of statuses. - Args: - statuses: Set of statuses to get trial indices for. - - Returns: - Set of trial indices for the given statuses. - """ - return { - i - for i, t in experiment.trials.items() - if (t.status in statuses) - and ( - (trial_type is None) - or ((trial_type is not None) and (t.trial_type == trial_type)) - ) - } + pass diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index a2fa2f68d77..b89b41ad2ed 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -14,10 +14,15 @@ from ax.core import BatchTrial, Experiment, Trial from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose + from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.data import Data from ax.core.evaluations_to_data import raw_evaluations_to_data -from ax.core.experiment import sort_by_trial_index_and_arm_name +from ax.core.experiment import ( + filter_trials_by_type, + get_trial_indices_for_statuses, + sort_by_trial_index_and_arm_name, +) from ax.core.map_metric import MapMetric from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective @@ -68,6 +73,7 @@ get_experiment_with_data, get_experiment_with_map_data_type, get_experiment_with_observations, + get_multi_type_experiment, get_optimization_config, get_optimization_config_no_constraints, get_scalarized_outcome_constraint, @@ -2463,3 +2469,273 @@ def test_extract_relevant_trials(self) -> None: ) self.assertEqual(len(trials), 1) self.assertEqual(trials[0].index, 0) + + +class MultiTypeExperimentTest(TestCase): + def setUp(self) -> None: + super().setUp() + self.experiment = get_multi_type_experiment() + + def test_MTExperimentFlow(self) -> None: + self.assertTrue(self.experiment.supports_trial_type("type1")) + self.assertTrue(self.experiment.supports_trial_type("type2")) + self.assertFalse(self.experiment.supports_trial_type(None)) + + n = 10 + arms = get_branin_arms(n=n, seed=0) + + b1 = self.experiment.new_batch_trial() + b1.add_arms_and_weights(arms=arms) + self.assertEqual(b1.trial_type, "type1") + b1.run() + self.assertEqual(b1.run_metadata["dummy_metadata"], "dummy1") + + self.experiment.update_runner("type2", SyntheticRunner(dummy_metadata="dummy3")) + b2 = self.experiment.new_batch_trial(trial_type="type2") + b2.add_arms_and_weights(arms=arms) + self.assertEqual(b2.trial_type, "type2") + b2.run() + self.assertEqual(b2.run_metadata["dummy_metadata"], "dummy3") + + df = self.experiment.fetch_data().df + for _, row in df.iterrows(): + # Make sure proper metric present for each batch only + self.assertEqual( + row["metric_name"], "m1" if row["trial_index"] == 0 else "m2" + ) + + arm_0_slice = df.loc[df["arm_name"] == "0_0"] + self.assertNotEqual( + float(arm_0_slice[df["trial_index"] == 
0]["mean"].item()), + float(arm_0_slice[df["trial_index"] == 1]["mean"].item()), + ) + self.assertEqual(len(df), 2 * n) + self.assertEqual(self.experiment.default_trials, {0}) + # Set 2 metrics to be equal + self.experiment.update_tracking_metric( + BraninMetric("m2", ["x1", "x2"]), trial_type="type2" + ) + df = self.experiment.fetch_data().df + arm_0_slice = df.loc[df["arm_name"] == "0_0"] + self.assertAlmostEqual( + float(arm_0_slice[df["trial_index"] == 0]["mean"].item()), + float(arm_0_slice[df["trial_index"] == 1]["mean"].item()), + places=10, + ) + + def test_Eq(self) -> None: + exp2 = get_multi_type_experiment() + + # Should be equal to start + self.assertTrue(self.experiment == exp2) + + self.experiment.add_tracking_metric( + BraninMetric("m3", ["x2", "x1"]), + trial_type="type1", + ) + + # Test different set of metrics + self.assertFalse(self.experiment == exp2) + + exp2.add_tracking_metric( + BraninMetric("m3", ["x2", "x1"]), + trial_type="type2", + ) + + # Test different metric definitions + self.assertFalse(self.experiment == exp2) + + exp2.update_tracking_metric( + BraninMetric("m3", ["x2", "x1"]), + trial_type="type1", + ) + + # Should be the same + self.assertTrue(self.experiment == exp2) + + exp2.remove_tracking_metric("m3") + self.assertFalse(self.experiment == exp2) + + def test_BadBehavior(self) -> None: + # Add trial type that already exists + with self.assertRaises(ValueError): + self.experiment.add_trial_type("type1", SyntheticRunner()) + + # Update runner for non-existent trial type + with self.assertRaises(ValueError): + self.experiment.update_runner("type3", SyntheticRunner()) + + # Add metric for trial_type that doesn't exist + with self.assertRaises(ValueError): + self.experiment.add_tracking_metric( + BraninMetric("m2", ["x1", "x2"]), "type3" + ) + + # Try to remove metric that doesn't exist + with self.assertRaises(ValueError): + self.experiment.remove_tracking_metric("m3") + + # Try to change optimization metric to non-primary trial type + with self.assertRaises(ValueError): + self.experiment.update_tracking_metric( + BraninMetric("m1", ["x1", "x2"]), "type2" + ) + + # Update metric definition for trial_type that doesn't exist + with self.assertRaises(ValueError): + self.experiment.update_tracking_metric( + BraninMetric("m2", ["x1", "x2"]), "type3" + ) + + # Try to get runner for trial_type that's not supported + batch = self.experiment.new_batch_trial() + batch._trial_type = "type3" # Force override trial type + with self.assertRaises(ValueError): + self.experiment.runner_for_trial_type(batch.trial_type) + + # Try making trial with unsupported trial type + with self.assertRaises(ValueError): + self.experiment.new_batch_trial(trial_type="type3") + + def test_setting_opt_config(self) -> None: + self.assertDictEqual( + self.experiment._metric_to_trial_type, {"m1": "type1", "m2": "type2"} + ) + self.experiment.optimization_config = OptimizationConfig( + Objective(BraninMetric("m3", ["x1", "x2"]), minimize=True) + ) + self.assertDictEqual( + self.experiment._metric_to_trial_type, + {"m1": "type1", "m2": "type2", "m3": "type1"}, + ) + + def test_runner_for_trial_type(self) -> None: + runner = self.experiment.runner_for_trial_type(trial_type="type1") + self.assertIs(runner, self.experiment._trial_type_to_runner["type1"]) + with self.assertRaisesRegex( + ValueError, "Trial type `invalid` is not supported." 
+ ): + self.experiment.runner_for_trial_type(trial_type="invalid") + + def test_add_tracking_metrics(self) -> None: + type1_metrics = [ + BraninMetric("m3_type1", ["x1", "x2"]), + BraninMetric("m4_type1", ["x1", "x2"]), + ] + type2_metrics = [ + BraninMetric("m3_type2", ["x1", "x2"]), + BraninMetric("m4_type2", ["x1", "x2"]), + ] + default_type_metrics = [ + BraninMetric("m5_default_type", ["x1", "x2"]), + ] + self.experiment.add_tracking_metrics( + metrics=type1_metrics + type2_metrics + default_type_metrics, + metrics_to_trial_types={ + "m3_type1": "type1", + "m4_type1": "type1", + "m3_type2": "type2", + "m4_type2": "type2", + }, + ) + self.assertDictEqual( + self.experiment._metric_to_trial_type, + { + "m1": "type1", + "m2": "type2", + "m3_type1": "type1", + "m4_type1": "type1", + "m3_type2": "type2", + "m4_type2": "type2", + "m5_default_type": "type1", + }, + ) + + def test_stop_trial_runs_multi_type_experiment(self) -> None: + # Setup 3 trials with 2 runners + self.experiment.new_batch_trial(trial_type="type1") + self.experiment.new_batch_trial(trial_type="type2") + self.experiment.new_batch_trial(trial_type="type2") + runner1 = self.experiment.runner_for_trial_type(trial_type="type1") + runner2 = self.experiment.runner_for_trial_type(trial_type="type2") + + # Mock is needed because SyntheticRunner does not implement a 'stop' + # method + with patch.object( + runner1, "stop", return_value=None + ) as mock_runner_stop1, patch.object( + runner2, "stop", return_value=None + ) as mock_runner_stop2, patch.object( + BaseTrial, "mark_early_stopped" + ) as mock_mark_stopped: + self.experiment.stop_trial_runs( + trials=[self.experiment.trials[0], self.experiment.trials[1]] + ) + mock_runner_stop1.assert_called_once() + mock_runner_stop2.assert_called() + mock_mark_stopped.assert_called() + + +class MultiTypeExperimentUtilsTest(TestCase): + def setUp(self) -> None: + super().setUp() + self.experiment = get_multi_type_experiment() + self.experiment.new_batch_trial(trial_type="type1") + self.experiment.new_batch_trial(trial_type="type2") + + def test_filter_trials_by_type(self) -> None: + trials = self.experiment.trials.values() + self.assertEqual(len(trials), 2) + filtered = filter_trials_by_type(trials, trial_type="type1") + self.assertEqual(len(filtered), 1) + self.assertEqual(filtered[0].trial_type, "type1") + filtered = filter_trials_by_type(trials, trial_type="type2") + self.assertEqual(len(filtered), 1) + self.assertEqual(filtered[0].trial_type, "type2") + filtered = filter_trials_by_type(trials, trial_type="invalid") + self.assertEqual(len(filtered), 0) + filtered = filter_trials_by_type(trials, trial_type=None) + self.assertEqual(len(filtered), 2) + + def test_get_trial_indices_for_statuses(self) -> None: + self.assertEqual( + get_trial_indices_for_statuses( + experiment=self.experiment, + statuses={TrialStatus.CANDIDATE, TrialStatus.STAGED}, + trial_type="type1", + ), + {0}, + ) + self.assertEqual( + get_trial_indices_for_statuses( + experiment=self.experiment, + statuses={TrialStatus.CANDIDATE, TrialStatus.STAGED}, + trial_type="type2", + ), + {1}, + ) + self.assertEqual( + get_trial_indices_for_statuses( + experiment=self.experiment, + statuses={TrialStatus.CANDIDATE, TrialStatus.STAGED}, + ), + {0, 1}, + ) + self.experiment.trials[0].mark_running(no_runner_required=True) + self.experiment.trials[1].mark_abandoned() + self.assertEqual( + get_trial_indices_for_statuses( + experiment=self.experiment, + statuses={TrialStatus.RUNNING}, + trial_type="type1", + ), + {0}, + ) + 
self.assertEqual( + get_trial_indices_for_statuses( + experiment=self.experiment, + statuses={TrialStatus.ABANDONED}, + trial_type="type2", + ), + {1}, + ) diff --git a/ax/core/tests/test_multi_type_experiment.py b/ax/core/tests/test_multi_type_experiment.py deleted file mode 100644 index 6b2779dd45f..00000000000 --- a/ax/core/tests/test_multi_type_experiment.py +++ /dev/null @@ -1,291 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from unittest.mock import patch - -from ax.core.base_trial import BaseTrial, TrialStatus -from ax.core.multi_type_experiment import ( - filter_trials_by_type, - get_trial_indices_for_statuses, -) -from ax.core.objective import Objective -from ax.core.optimization_config import OptimizationConfig -from ax.metrics.branin import BraninMetric -from ax.runners.synthetic import SyntheticRunner -from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_branin_arms, get_multi_type_experiment - - -class MultiTypeExperimentTest(TestCase): - def setUp(self) -> None: - super().setUp() - self.experiment = get_multi_type_experiment() - - def test_MTExperimentFlow(self) -> None: - self.assertTrue(self.experiment.supports_trial_type("type1")) - self.assertTrue(self.experiment.supports_trial_type("type2")) - self.assertFalse(self.experiment.supports_trial_type(None)) - - n = 10 - arms = get_branin_arms(n=n, seed=0) - - b1 = self.experiment.new_batch_trial() - b1.add_arms_and_weights(arms=arms) - self.assertEqual(b1.trial_type, "type1") - b1.run() - self.assertEqual(b1.run_metadata["dummy_metadata"], "dummy1") - - self.experiment.update_runner("type2", SyntheticRunner(dummy_metadata="dummy3")) - b2 = self.experiment.new_batch_trial(trial_type="type2") - b2.add_arms_and_weights(arms=arms) - self.assertEqual(b2.trial_type, "type2") - b2.run() - self.assertEqual(b2.run_metadata["dummy_metadata"], "dummy3") - - df = self.experiment.fetch_data().df - for _, row in df.iterrows(): - # Make sure proper metric present for each batch only - self.assertEqual( - row["metric_name"], "m1" if row["trial_index"] == 0 else "m2" - ) - - arm_0_slice = df.loc[df["arm_name"] == "0_0"] - self.assertNotEqual( - float(arm_0_slice[df["trial_index"] == 0]["mean"].item()), - float(arm_0_slice[df["trial_index"] == 1]["mean"].item()), - ) - self.assertEqual(len(df), 2 * n) - self.assertEqual(self.experiment.default_trials, {0}) - # Set 2 metrics to be equal - self.experiment.update_tracking_metric( - BraninMetric("m2", ["x1", "x2"]), trial_type="type2" - ) - df = self.experiment.fetch_data().df - arm_0_slice = df.loc[df["arm_name"] == "0_0"] - self.assertAlmostEqual( - float(arm_0_slice[df["trial_index"] == 0]["mean"].item()), - float(arm_0_slice[df["trial_index"] == 1]["mean"].item()), - places=10, - ) - - def test_Repr(self) -> None: - self.assertEqual(str(self.experiment), "MultiTypeExperiment(test_exp)") - - def test_Eq(self) -> None: - exp2 = get_multi_type_experiment() - - # Should be equal to start - self.assertTrue(self.experiment == exp2) - - self.experiment.add_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m4" - ) - - # Test different set of metrics - self.assertFalse(self.experiment == exp2) - - exp2.add_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m5" - ) - - # Test different metric 
definitions - self.assertFalse(self.experiment == exp2) - - exp2.update_tracking_metric( - BraninMetric("m3", ["x2", "x1"]), trial_type="type1", canonical_name="m4" - ) - - # Should be the same - self.assertTrue(self.experiment == exp2) - - exp2.remove_tracking_metric("m3") - self.assertFalse(self.experiment == exp2) - - def test_BadBehavior(self) -> None: - # Add trial type that already exists - with self.assertRaises(ValueError): - self.experiment.add_trial_type("type1", SyntheticRunner()) - - # Update runner for non-existent trial type - with self.assertRaises(ValueError): - self.experiment.update_runner("type3", SyntheticRunner()) - - # Add metric for trial_type that doesn't exist - with self.assertRaises(ValueError): - self.experiment.add_tracking_metric( - BraninMetric("m2", ["x1", "x2"]), "type3" - ) - - # Try to remove metric that doesn't exist - with self.assertRaises(ValueError): - self.experiment.remove_tracking_metric("m3") - - # Try to change optimization metric to non-primary trial type - with self.assertRaises(ValueError): - self.experiment.update_tracking_metric( - BraninMetric("m1", ["x1", "x2"]), "type2" - ) - - # Update metric definition for trial_type that doesn't exist - with self.assertRaises(ValueError): - self.experiment.update_tracking_metric( - BraninMetric("m2", ["x1", "x2"]), "type3" - ) - - # Try to get runner for trial_type that's not supported - batch = self.experiment.new_batch_trial() - batch._trial_type = "type3" # Force override trial type - with self.assertRaises(ValueError): - self.experiment.runner_for_trial_type(batch.trial_type) - - # Try making trial with unsupported trial type - with self.assertRaises(ValueError): - self.experiment.new_batch_trial(trial_type="type3") - - def test_setting_opt_config(self) -> None: - self.assertDictEqual( - self.experiment._metric_to_trial_type, {"m1": "type1", "m2": "type2"} - ) - self.experiment.optimization_config = OptimizationConfig( - Objective(BraninMetric("m3", ["x1", "x2"]), minimize=True) - ) - self.assertDictEqual( - self.experiment._metric_to_trial_type, - {"m1": "type1", "m2": "type2", "m3": "type1"}, - ) - - def test_runner_for_trial_type(self) -> None: - runner = self.experiment.runner_for_trial_type(trial_type="type1") - self.assertIs(runner, self.experiment._trial_type_to_runner["type1"]) - with self.assertRaisesRegex( - ValueError, "Trial type `invalid` is not supported." 
- ): - self.experiment.runner_for_trial_type(trial_type="invalid") - - def test_add_tracking_metrics(self) -> None: - type1_metrics = [ - BraninMetric("m3_type1", ["x1", "x2"]), - BraninMetric("m4_type1", ["x1", "x2"]), - ] - type2_metrics = [ - BraninMetric("m3_type2", ["x1", "x2"]), - BraninMetric("m4_type2", ["x1", "x2"]), - ] - default_type_metrics = [ - BraninMetric("m5_default_type", ["x1", "x2"]), - ] - self.experiment.add_tracking_metrics( - metrics=type1_metrics + type2_metrics + default_type_metrics, - metrics_to_trial_types={ - "m3_type1": "type1", - "m4_type1": "type1", - "m3_type2": "type2", - "m4_type2": "type2", - }, - ) - self.assertDictEqual( - self.experiment._metric_to_trial_type, - { - "m1": "type1", - "m2": "type2", - "m3_type1": "type1", - "m4_type1": "type1", - "m3_type2": "type2", - "m4_type2": "type2", - "m5_default_type": "type1", - }, - ) - - def test_stop_trial_runs_multi_type_experiment(self) -> None: - # Setup 3 trials with 2 runners - self.experiment.new_batch_trial(trial_type="type1") - self.experiment.new_batch_trial(trial_type="type2") - self.experiment.new_batch_trial(trial_type="type2") - runner1 = self.experiment.runner_for_trial_type(trial_type="type1") - runner2 = self.experiment.runner_for_trial_type(trial_type="type2") - - # Mock is needed because SyntheticRunner does not implement a 'stop' - # method - with patch.object( - runner1, "stop", return_value=None - ) as mock_runner_stop1, patch.object( - runner2, "stop", return_value=None - ) as mock_runner_stop2, patch.object( - BaseTrial, "mark_early_stopped" - ) as mock_mark_stopped: - self.experiment.stop_trial_runs( - trials=[self.experiment.trials[0], self.experiment.trials[1]] - ) - mock_runner_stop1.assert_called_once() - mock_runner_stop2.assert_called() - mock_mark_stopped.assert_called() - - -class MultiTypeExperimentUtilsTest(TestCase): - def setUp(self) -> None: - super().setUp() - self.experiment = get_multi_type_experiment() - self.experiment.new_batch_trial(trial_type="type1") - self.experiment.new_batch_trial(trial_type="type2") - - def test_filter_trials_by_type(self) -> None: - trials = self.experiment.trials.values() - self.assertEqual(len(trials), 2) - filtered = filter_trials_by_type(trials, trial_type="type1") - self.assertEqual(len(filtered), 1) - self.assertEqual(filtered[0].trial_type, "type1") - filtered = filter_trials_by_type(trials, trial_type="type2") - self.assertEqual(len(filtered), 1) - self.assertEqual(filtered[0].trial_type, "type2") - filtered = filter_trials_by_type(trials, trial_type="invalid") - self.assertEqual(len(filtered), 0) - filtered = filter_trials_by_type(trials, trial_type=None) - self.assertEqual(len(filtered), 2) - - def test_get_trial_indices_for_statuses(self) -> None: - self.assertEqual( - get_trial_indices_for_statuses( - experiment=self.experiment, - statuses={TrialStatus.CANDIDATE, TrialStatus.STAGED}, - trial_type="type1", - ), - {0}, - ) - self.assertEqual( - get_trial_indices_for_statuses( - experiment=self.experiment, - statuses={TrialStatus.CANDIDATE, TrialStatus.STAGED}, - trial_type="type2", - ), - {1}, - ) - self.assertEqual( - get_trial_indices_for_statuses( - experiment=self.experiment, - statuses={TrialStatus.CANDIDATE, TrialStatus.STAGED}, - ), - {0, 1}, - ) - self.experiment.trials[0].mark_running(no_runner_required=True) - self.experiment.trials[1].mark_abandoned() - self.assertEqual( - get_trial_indices_for_statuses( - experiment=self.experiment, - statuses={TrialStatus.RUNNING}, - trial_type="type1", - ), - {0}, - ) - 
self.assertEqual( - get_trial_indices_for_statuses( - experiment=self.experiment, - statuses={TrialStatus.ABANDONED}, - trial_type="type2", - ), - {1}, - ) diff --git a/ax/orchestration/orchestrator.py b/ax/orchestration/orchestrator.py index b5c74201371..ddf48fac1de 100644 --- a/ax/orchestration/orchestrator.py +++ b/ax/orchestration/orchestrator.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. # pyre-strict - from __future__ import annotations from collections.abc import Callable, Generator, Iterable, Mapping @@ -21,14 +20,14 @@ from ax.adapter.adapter_utils import get_fixed_features_from_experiment from ax.adapter.base import Adapter from ax.core.base_trial import BaseTrial -from ax.core.experiment import Experiment -from ax.core.generator_run import GeneratorRun -from ax.core.metric import Metric, MetricFetchE, MetricFetchResult -from ax.core.multi_type_experiment import ( +from ax.core.experiment import ( + Experiment, filter_trials_by_type, get_trial_indices_for_statuses, - MultiTypeExperiment, ) +from ax.core.generator_run import GeneratorRun +from ax.core.metric import Metric, MetricFetchE, MetricFetchResult +from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.runner import Runner from ax.core.trial import Trial from ax.core.trial_status import TrialStatus diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index 58dead781e5..35248e8355c 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -2740,7 +2740,7 @@ def setUp(self) -> None: Objective(Metric(name="branin"), minimize=True) ), default_trial_type="type1", - default_runner=None, + runner=None, name="branin_experiment_no_impl_runner_or_metrics", ) self.sobol_MBM_GS = choose_generation_strategy_legacy( diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index a02047eff3b..20c6d0f60de 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -417,7 +417,6 @@ def add_tracking_metrics( metric_names: list[str], metric_definitions: dict[str, dict[str, Any]] | None = None, metrics_to_trial_types: dict[str, str] | None = None, - canonical_names: dict[str, str] | None = None, ) -> None: """Add a list of new metrics to the experiment. @@ -435,9 +434,6 @@ def add_tracking_metrics( trial types for each metric. If provided, the metrics will be added with their respective trial types. If not provided, then the default trial type will be used. - canonical_names: A mapping from metric name (of a particular trial type) - to the metric name of the default trial type. Only applicable to - MultiTypeExperiment. 
""" metric_definitions = ( self.metric_definitions @@ -454,7 +450,6 @@ def add_tracking_metrics( experiment.add_tracking_metrics( metrics=metric_objects, metrics_to_trial_types=metrics_to_trial_types, - canonical_names=canonical_names, ) else: self.experiment.add_tracking_metrics(metrics=metric_objects) diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index 422f813249f..bafa02ab7ef 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -908,7 +908,7 @@ def make_experiment( parameters=parameters, parameter_constraints=parameter_constraints ), default_trial_type=none_throws(default_trial_type), - default_runner=none_throws(default_runner), + runner=none_throws(default_runner), optimization_config=optimization_config, tracking_metrics=tracking_metrics, status_quo=status_quo_arm, diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index e467f0d7886..9af31218ab6 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -602,7 +602,6 @@ def multi_type_experiment_from_json( """Load AE MultiTypeExperiment from JSON.""" experiment_info = _get_experiment_info(object_json) - _metric_to_canonical_name = object_json.pop("_metric_to_canonical_name") _metric_to_trial_type = object_json.pop("_metric_to_trial_type") _trial_type_to_runner = object_from_json( object_json.pop("_trial_type_to_runner"), @@ -630,7 +629,6 @@ def multi_type_experiment_from_json( experiment = MultiTypeExperiment(**kwargs) for metric in tracking_metrics: experiment._tracking_metrics[metric.name] = metric - experiment._metric_to_canonical_name = _metric_to_canonical_name experiment._metric_to_trial_type = _metric_to_trial_type experiment._trial_type_to_runner = _trial_type_to_runner diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 6ee54fbad8a..a0be0af6bf8 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -104,7 +104,6 @@ def multi_type_experiment_to_dict(experiment: MultiTypeExperiment) -> dict[str, """Convert AE multitype experiment to a dictionary.""" multi_type_dict = { "default_trial_type": experiment._default_trial_type, - "_metric_to_canonical_name": experiment._metric_to_canonical_name, "_metric_to_trial_type": experiment._metric_to_trial_type, "_trial_type_to_runner": experiment._trial_type_to_runner, } diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 23742fbeaea..836e950811e 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -305,7 +305,7 @@ def _init_mt_experiment_from_sqa( description=experiment_sqa.description, search_space=search_space, default_trial_type=default_trial_type, - default_runner=trial_type_to_runner.get(default_trial_type), + runner=trial_type_to_runner.get(default_trial_type), optimization_config=opt_config, status_quo=status_quo, properties=properties, @@ -320,7 +320,6 @@ def _init_mt_experiment_from_sqa( experiment.add_tracking_metric( tracking_metric, trial_type=none_throws(sqa_metric.trial_type), - canonical_name=sqa_metric.canonical_name, ) return experiment diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 4fdf9ba51f7..2252b1155ba 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -220,10 +220,6 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: for metric in tracking_metrics: metric.trial_type = 
experiment._metric_to_trial_type[metric.name] - if metric.name in experiment._metric_to_canonical_name: - metric.canonical_name = experiment._metric_to_canonical_name[ - metric.name - ] elif experiment.runner: runners.append(self.runner_to_sqa(none_throws(experiment.runner))) properties = experiment._properties.copy() diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index e4f6485e0a9..72b45eac03c 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -813,11 +813,8 @@ def test_MTExperimentSaveAndLoad(self) -> None: loaded_experiment = load_experiment(experiment.name) self.assertEqual(loaded_experiment.default_trial_type, "type1") self.assertEqual(len(loaded_experiment._trial_type_to_runner), 2) - # pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`. self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1") self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2") - # pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`. - self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1") self.assertEqual(len(loaded_experiment.trials), 2) def test_MTExperimentSaveAndLoadSkipRunnersAndMetrics(self) -> None: @@ -829,11 +826,8 @@ def test_MTExperimentSaveAndLoadSkipRunnersAndMetrics(self) -> None: self.assertEqual(loaded_experiment.default_trial_type, "type1") self.assertIsNone(loaded_experiment._trial_type_to_runner["type1"]) self.assertIsNone(loaded_experiment._trial_type_to_runner["type2"]) - # pyre-fixme[16]: `Experiment` has no attribute `metric_to_trial_type`. self.assertEqual(loaded_experiment.metric_to_trial_type["m1"], "type1") self.assertEqual(loaded_experiment.metric_to_trial_type["m2"], "type2") - # pyre-fixme[16]: `Experiment` has no attribute `_metric_to_canonical_name`. - self.assertEqual(loaded_experiment._metric_to_canonical_name["m2"], "m1") self.assertEqual(len(loaded_experiment.trials), 2) def test_ExperimentNewTrial(self) -> None: diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 6d6dff0e638..16d18e16b6b 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -610,7 +610,7 @@ def get_multi_type_experiment( name="test_exp", search_space=get_branin_search_space(), default_trial_type="type1", - default_runner=SyntheticRunner(dummy_metadata="dummy1"), + runner=SyntheticRunner(dummy_metadata="dummy1"), optimization_config=oc, status_quo=Arm(parameters={"x1": 0.0, "x2": 0.0}), ) @@ -620,7 +620,8 @@ def get_multi_type_experiment( ) # Switch the order of variables so metric gives different results experiment.add_tracking_metric( - BraninMetric("m2", ["x2", "x1"]), trial_type="type2", canonical_name="m1" + BraninMetric("m2", ["x2", "x1"]), + trial_type="type2", ) if add_trials and add_trial_type:
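
Illustrative usage sketch (not part of the diff above): how the multi-type API upstreamed onto `Experiment` by this change can be exercised, mirroring `get_multi_type_experiment` and the new tests in `test_experiment.py`. The constructor keywords, `BraninMetric`, `SyntheticRunner`, and `get_branin_search_space` are assumed to behave as in the Ax test stubs used elsewhere in this diff.

    from ax.core.experiment import Experiment
    from ax.core.objective import Objective
    from ax.core.optimization_config import OptimizationConfig
    from ax.metrics.branin import BraninMetric
    from ax.runners.synthetic import SyntheticRunner
    from ax.utils.testing.core_stubs import get_branin_search_space

    # Multi-type behavior now lives on Experiment itself; MultiTypeExperiment is
    # deprecated. "type1" is the default trial type and receives the default runner.
    experiment = Experiment(
        name="multi_type_example",
        search_space=get_branin_search_space(),
        optimization_config=OptimizationConfig(
            Objective(BraninMetric("m1", ["x1", "x2"]), minimize=True)
        ),
        default_trial_type="type1",
        runner=SyntheticRunner(dummy_metadata="dummy1"),
    )

    # Register a second trial type with its own runner and attach a tracking
    # metric to it.
    experiment.add_trial_type("type2", SyntheticRunner(dummy_metadata="dummy2"))
    experiment.add_tracking_metric(BraninMetric("m2", ["x2", "x1"]), trial_type="type2")

    assert experiment.is_multi_type
    assert experiment.metric_to_trial_type == {"m1": "type1", "m2": "type2"}

    # Trials take the default type unless one is requested explicitly; fetch_data()
    # combines per-trial data, fetching only the metrics mapped to each trial's type.
    b1 = experiment.new_batch_trial()                    # trial_type == "type1"
    b2 = experiment.new_batch_trial(trial_type="type2")  # uses the "type2" runner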