235 changes: 205 additions & 30 deletions ax/core/experiment.py
@@ -9,7 +9,6 @@
from __future__ import annotations

import inspect

import logging
import warnings
from collections import defaultdict
@@ -150,10 +149,13 @@ def __init__(
self._trials: dict[int, BaseTrial] = {}
self._properties: dict[str, Any] = properties or {}

# Specifies which trial type each metric belongs to
self._metric_to_trial_type: dict[str, str | None] = {}

# Initialize trial type to runner mapping
- self._default_trial_type = default_trial_type
+ self._default_trial_type: str | None = default_trial_type
self._trial_type_to_runner: dict[str | None, Runner | None] = {
- default_trial_type: runner
+ self._default_trial_type: runner
}
# Used to keep track of whether any trials on the experiment
# specify a TTL. Since trials need to be checked for their TTL's
@@ -417,13 +419,13 @@ def runner(self, runner: Runner | None) -> None:
if runner is not None:
self._trial_type_to_runner[self._default_trial_type] = runner
else:
- self._trial_type_to_runner = {None: None}
+ self._trial_type_to_runner = {self._default_trial_type: None}

@runner.deleter
def runner(self) -> None:
"""Delete the runner."""
self._runner = None
- self._trial_type_to_runner = {None: None}
+ self._trial_type_to_runner = {self._default_trial_type: None}
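
# A minimal construction sketch for the multi-type plumbing above, assuming
# the constructor and `add_trial_type` as wired up in this diff. The names
# ("prod", "simulator", "objective", "x") and the use of SyntheticRunner as
# a stand-in for real deployment logic are illustrative assumptions.
from ax.core.experiment import Experiment
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.runners.synthetic import SyntheticRunner

exp = Experiment(
    name="prod_tuning_sketch",
    search_space=SearchSpace(
        parameters=[
            RangeParameter(
                name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
            )
        ]
    ),
    optimization_config=OptimizationConfig(
        objective=Objective(metric=Metric(name="objective"), minimize=True)
    ),
    runner=SyntheticRunner(),  # runner for the default trial type
    default_trial_type="prod",
)
exp.add_trial_type(trial_type="simulator", runner=SyntheticRunner())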

@property
def parameters(self) -> dict[str, Parameter]:
@@ -493,6 +495,11 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
for metric_name in optimization_config.metrics.keys():
if metric_name in self._tracking_metrics:
self.remove_tracking_metric(metric_name)

# Optimization config metrics are required to be the default trial type
# currently. TODO: remove that restriction (T202797235)
self._metric_to_trial_type[metric_name] = self.default_trial_type

# add metrics from the previous optimization config that are not in the new
# optimization config as tracking metrics
prev_optimization_config = self._optimization_config
@@ -554,11 +561,16 @@ def immutable_search_space_and_opt_config(self) -> bool:
def tracking_metrics(self) -> list[Metric]:
return list(self._tracking_metrics.values())

- def add_tracking_metric(self, metric: Metric) -> Experiment:
+ def add_tracking_metric(
+ self,
+ metric: Metric,
+ trial_type: str | None = None,
+ ) -> Experiment:
"""Add a new metric to the experiment.

Args:
metric: Metric to be added.
trial_type: The trial type for which this metric is used.
"""
if metric.name in self._tracking_metrics:
raise ValueError(
@@ -574,34 +586,72 @@ def add_tracking_metric(self, metric: Metric) -> Experiment:
"before adding it to tracking metrics."
)

if trial_type is None:
trial_type = self._default_trial_type

self._tracking_metrics[metric.name] = metric
self._metric_to_trial_type[metric.name] = trial_type

return self
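
# Usage sketch for the new trial_type argument, continuing the illustrative
# experiment built in the construction sketch above:
exp.add_tracking_metric(Metric(name="sim_error"), trial_type="simulator")
exp.add_tracking_metric(Metric(name="cost"))  # falls back to the "prod" default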

- def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment:
+ def add_tracking_metrics(
+ self,
+ metrics: list[Metric],
+ metrics_to_trial_types: dict[str, str] | None = None,
+ ) -> Experiment:
"""Add a list of new metrics to the experiment.

If any of the metrics are already defined on the experiment,
we raise an error and don't add any of them to the experiment.

Args:
metrics: Metrics to be added.
metrics_to_trial_types: The mapping from metric names to corresponding
trial types for each metric. If provided, the metrics will be
added to their trial types. If not provided, then the default
trial type will be used.
"""
# Before setting any metrics, we validate none are already on
# the experiment
metrics_to_trial_types = metrics_to_trial_types or {}

for metric in metrics:
- self.add_tracking_metric(metric)
+ self.add_tracking_metric(
+ metric=metric,
+ trial_type=metrics_to_trial_types.get(metric.name),
+ )
return self
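
# Sketch of the batch variant: metrics without an entry in
# metrics_to_trial_types get the default trial type.
exp.add_tracking_metrics(
    metrics=[Metric(name="latency"), Metric(name="sim_latency")],
    metrics_to_trial_types={"sim_latency": "simulator"},
)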

- def update_tracking_metric(self, metric: Metric) -> Experiment:
+ def update_tracking_metric(
+ self,
+ metric: Metric,
+ trial_type: str | None = None,
+ ) -> Experiment:
"""Redefine a metric that already exists on the experiment.

Args:
metric: New metric definition.
trial_type: The trial type for which this metric is used.
"""
if trial_type is None:
trial_type = self._default_trial_type

oc = self.optimization_config
oc_metrics = oc.metrics if oc else []
if metric.name in oc_metrics and trial_type != self._default_trial_type:
raise ValueError(
f"Metric `{metric.name}` must remain a "
f"`{self._default_trial_type}` metric because it is part of the "
"optimization_config."
)

if not self.supports_trial_type(trial_type):
raise ValueError(f"`{trial_type}` is not a supported trial type.")

if metric.name not in self._tracking_metrics:
raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.")

self._tracking_metrics[metric.name] = metric
self._metric_to_trial_type[metric.name] = trial_type

return self
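
# Sketch of the optimization-config guard above, continuing the illustrative
# experiment: "objective" is part of the optimization config, so it may not
# be moved off the default trial type.
try:
    exp.update_tracking_metric(Metric(name="objective"), trial_type="simulator")
except ValueError:
    pass  # "objective" must remain a "prod" (default trial type) metric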

def remove_tracking_metric(self, metric_name: str) -> Experiment:
@@ -614,6 +664,8 @@ def remove_tracking_metric(self, metric_name: str) -> Experiment:
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")

del self._tracking_metrics[metric_name]
del self._metric_to_trial_type[metric_name]

return self

@property
Expand Down Expand Up @@ -777,6 +829,21 @@ def fetch_data(
Returns:
Data for the experiment.
"""
if self.is_multi_type:
# TODO: make this more efficient for fetching
# data for multiple trials of the same type
# by overriding Experiment._lookup_or_fetch_trials_results
return Data.from_multiple_data(
[
(
trial.fetch_data(**kwargs, metrics=metrics)
if trial.status.expecting_data
else Data()
)
for trial in self.trials.values()
]
)

results = self._lookup_or_fetch_trials_results(
trials=list(self.trials.values())
if trial_indices is None
@@ -853,8 +920,16 @@ def _fetch_trial_data(
) -> dict[str, MetricFetchResult]:
trial = self.trials[trial_index]

metrics_for_trial_type = [
metric
for metric in metrics or self.metrics.values()
if self.metric_to_trial_type[metric.name] == trial.trial_type
]

trial_data = self._lookup_or_fetch_trials_results(
- trials=[trial], metrics=metrics, **kwargs
+ trials=[trial],
+ metrics=metrics_for_trial_type,
+ **kwargs,
)

if trial_index in trial_data:
Expand Down Expand Up @@ -1566,19 +1641,43 @@ def __repr__(self) -> str:
return self.__class__.__name__ + f"({self._name})"

# --- MultiTypeExperiment convenience functions ---
#
- # Certain functionalities have special behavior for multi-type experiments.
- # This defines the base behavior for regular experiments that will be
- # overridden in the MultiTypeExperiment class.
+ # A canonical use case for this is tuning a large production system
+ # with limited evaluation budget and a simulator which approximates
+ # evaluations on the main system. Trial deployment and data fetching
+ # are separate for the two systems, but the final data is combined and
+ # fed into multi-task models.

@property
def is_multi_type(self) -> bool:
"""Whether this Experiment contains more than one trial type."""
return len(self._trial_type_to_runner) > 1
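
# Continuing the illustrative sketch: with "prod" and "simulator" registered,
# the experiment reports itself as multi-type.
assert exp.is_multi_type
assert exp.default_trials == set()  # no trials have been created yet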

@property
def default_trial_type(self) -> str | None:
"""Default trial type assigned to trials in this experiment.
"""Default trial type assigned to trials in this experiment."""
return self._default_trial_type

In the base experiment class this is always None. For experiments
with multiple trial types, use the MultiTypeExperiment class.
@property
def default_trials(self) -> set[int]:
"""Return the indicies for trials of the default type."""
return {
idx
for idx, trial in self.trials.items()
if trial.trial_type == self.default_trial_type
}

def add_trial_type(self, trial_type: str, runner: Runner) -> "Experiment":
"""Add a new trial_type to be supported by this experiment.

Args:
trial_type: The new trial_type to be added.
runner: The default runner for trials of this type.
"""
return self._default_trial_type
if self.supports_trial_type(trial_type):
raise ValueError(f"Experiment already contains trial_type `{trial_type}`")

self._trial_type_to_runner[trial_type] = runner
return self

def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
"""The default runner to use for a given trial type.
@@ -1591,20 +1690,55 @@ def runner_for_trial_type(self, trial_type: str | None) -> Runner | None:
return self.runner # return the default runner
return runner

def update_runner(self, trial_type: str, runner: Runner) -> "Experiment":
"""Update the default runner for an existing trial_type.

Args:
trial_type: The trial_type whose runner should be updated.
runner: The new runner for trials of this type.
"""
if not self.supports_trial_type(trial_type):
raise ValueError(f"Experiment does not contain trial_type `{trial_type}`")

self._trial_type_to_runner[trial_type] = runner
self._runner = runner
return self
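
# Sketch: swapping in a fresh runner for an already-registered trial type.
exp.update_runner(trial_type="simulator", runner=SyntheticRunner())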

@property
def metric_to_trial_type(self) -> dict[str, str]:
"""Map metrics to trial types.

Combines the default trial type for optimization-config metrics with the
custom trial types assigned to tracking metrics.
"""
if self.optimization_config is not None:
opt_config_types = {
metric_name: self.default_trial_type
for metric_name in self.optimization_config.metrics.keys()
}
else:
opt_config_types = {}

return {**opt_config_types, **self._metric_to_trial_type}
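
# Sketch of the merge semantics above, continuing the illustrative names:
# optimization-config metrics map to the default trial type, while explicit
# assignments on tracking metrics are preserved.
assert exp.metric_to_trial_type["objective"] == "prod"
assert exp.metric_to_trial_type["sim_error"] == "simulator"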

def metrics_for_trial_type(self, trial_type: str) -> list[Metric]:
"""The default runner to use for a given trial type.

Looks up the appropriate runner for this trial type in the trial_type_to_runner.
"""
if not self.supports_trial_type(trial_type):
raise ValueError(f"Trial type `{trial_type}` is not supported.")
return [
self.metrics[metric_name]
for metric_name, metric_trial_type in self._metric_to_trial_type.items()
if metric_trial_type == trial_type
]
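
# Sketch, continuing the illustrative names from the sketches above:
assert {m.name for m in exp.metrics_for_trial_type("simulator")} == {
    "sim_error",
    "sim_latency",
}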

def supports_trial_type(self, trial_type: str | None) -> bool:
"""Whether this experiment allows trials of the given type.

- The base experiment class only supports None. For experiments
- with multiple trial types, use the MultiTypeExperiment class.
+ Only trial types defined in the trial_type_to_runner are allowed.
"""
- return (
- trial_type is None
- # We temporarily allow "short run" and "long run" trial
- # types in single-type experiments during development of
- # a new ``GenerationStrategy`` that needs them.
- or trial_type == Keys.SHORT_RUN
- or trial_type == Keys.LONG_RUN
- )
+ return trial_type in self._trial_type_to_runner.keys()

def attach_trial(
self,
Expand Down Expand Up @@ -2206,3 +2340,44 @@ def add_arm_and_prevent_naming_collision(
stacklevel=2,
)
new_trial.add_arm(none_throws(old_trial.arm).clone(clear_name=True))


def filter_trials_by_type(
trials: Sequence[BaseTrial], trial_type: str | None
) -> list[BaseTrial]:
"""Filter trials by trial type if provided.

This filters trials by trial type if the experiment is a
MultiTypeExperiment.

Args:
trials: Trials to filter.

Returns:
Filtered trials.
"""
if trial_type is not None:
return [t for t in trials if t.trial_type == trial_type]
return list(trials)
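
# Sketch for the helper above; trial_type=None performs no filtering.
sim_trials = filter_trials_by_type(
    trials=list(exp.trials.values()), trial_type="simulator"
)
all_trials = filter_trials_by_type(trials=list(exp.trials.values()), trial_type=None)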


def get_trial_indices_for_statuses(
experiment: Experiment, statuses: set[TrialStatus], trial_type: str | None = None
) -> set[int]:
"""Get trial indices for a set of statuses.

Args:
experiment: Experiment whose trials to inspect.
statuses: Set of statuses to get trial indices for.
trial_type: If provided, restrict results to trials of this type.

Returns:
Set of trial indices for the given statuses.
"""
return {
i
for i, t in experiment.trials.items()
if (t.status in statuses)
and (
(trial_type is None)
or ((trial_type is not None) and (t.trial_type == trial_type))
)
}
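
# Sketch for the helper above: indices of completed "simulator" trials in
# the illustrative experiment used throughout these sketches.
from ax.core.base_trial import TrialStatus

completed_sim = get_trial_indices_for_statuses(
    experiment=exp,
    statuses={TrialStatus.COMPLETED},
    trial_type="simulator",
)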