Skip to content

Commit be89b7a

Browse files
mgarrard authored and facebook-github-bot committed
Update target trial selection logic to only consider trials with data for optimization config metrics (facebook#4742)
Summary: This updates the target trial selection logic in the following ways:
1. We only consider trials that either (a) have data for *all* opt config metrics or (b) have data for *all* metrics. Previously, a trial that had data for only some opt config metrics passed the check, but this partial-data setup causes issues downstream.
2. If there is no long-run trial, we fall back to a short-run trial instead of not identifying a target trial.
3. We filter out "stale" trials, i.e., trials that were completed over 10 days ago. Hitting this point would be pretty far down our priority list, but it was an idea ItsMrLin initially had that I thought was really interesting.
4. Lastly, if no trials remain after stale trials are filtered out, we use the stale ones anyway (necessary for benchmarking).
5. If we still can't find anything, it returns None.

Reviewed By: ItsMrLin
Differential Revision: D90089411
Privacy Context Container: L1307644
1 parent aa29dd8 commit be89b7a

File tree

8 files changed

+212
-60
lines changed

8 files changed

+212
-60
lines changed

ax/adapter/transforms/tests/test_transform_to_new_sq.py

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -182,18 +182,6 @@ def test_target_trial_index(self) -> None:
182182
)
183183

184184
self.assertEqual(t.default_trial_idx, 0)
185-
# test falling back to latest trial with SQ data
186-
with mock.patch(
187-
"ax.adapter.transforms.transform_to_new_sq.get_target_trial_index",
188-
return_value=10,
189-
):
190-
t = TransformToNewSQ(
191-
search_space=self.exp.search_space,
192-
experiment_data=experiment_data,
193-
adapter=self.adapter,
194-
)
195-
196-
self.assertEqual(t.default_trial_idx, 1)
197185

198186
def test_transform_experiment_data(self) -> None:
199187
# Create two more trials with different SQ observations.

ax/adapter/transforms/tests/test_winsorize_transform.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -535,7 +535,13 @@ def test_relative_constraints(self) -> None:
535535
"trial_index": t.index,
536536
"metric_signature": metric_name,
537537
}
538-
for metric_name, mean, sem in (("a", 1.0, 2.0), ("b", 2.0, 4.0))
538+
# Needs data for all metrics in opt config to identify target
539+
# trial for transforms
540+
for metric_name, mean, sem in (
541+
("a", 1.0, 2.0),
542+
("b", 2.0, 4.0),
543+
("c", 3.0, 1.0),
544+
)
539545
]
540546
)
541547
)
@@ -553,7 +559,7 @@ def test_relative_constraints(self) -> None:
553559
adapter=adapter,
554560
)
555561
self.assertDictEqual(
556-
t.cutoffs, {"a": (-INF, INF), "b": (-INF, INF), "c": (0.5, INF)}
562+
t.cutoffs, {"a": (-INF, INF), "b": (-INF, INF), "c": (-3.25, INF)}
557563
)
558564
# Winsorizes with `derelativize_with_raw_status_quo`.
559565
t = Winsorize(
@@ -563,7 +569,7 @@ def test_relative_constraints(self) -> None:
563569
config={"derelativize_with_raw_status_quo": True},
564570
)
565571
self.assertDictEqual(
566-
t.cutoffs, {"a": (-INF, 4.25), "b": (-INF, 4.25), "c": (0.5, INF)}
572+
t.cutoffs, {"a": (-INF, 4.25), "b": (-INF, 4.25), "c": (-3.25, INF)}
567573
)
568574

569575
def test_transform_experiment_data(self) -> None:

ax/adapter/transforms/transform_to_new_sq.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -77,13 +77,6 @@ def __init__(
7777
target_trial_index = get_target_trial_index(
7878
experiment=none_throws(adapter)._experiment
7979
)
80-
trials_indices_with_sq_data = self.status_quo_data_by_trial.keys()
81-
if target_trial_index not in trials_indices_with_sq_data:
82-
target_trial_index = max(trials_indices_with_sq_data)
83-
logger.warning(
84-
"No status quo data for target trial. Failing back to "
85-
f"{target_trial_index}."
86-
)
8780

8881
if target_trial_index is not None:
8982
self.default_trial_idx: int = assert_is_instance(

ax/analysis/plotly/surface/contour.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -325,7 +325,12 @@ def _prepare_data(
325325
)
326326

327327
if relativize:
328-
target_trial_index = none_throws(get_target_trial_index(experiment=experiment))
328+
target_trial_index = none_throws(
329+
get_target_trial_index(
330+
experiment=experiment,
331+
require_data_for_all_metrics=True,
332+
)
333+
)
329334
df = relativize_data(
330335
experiment=experiment,
331336
df=df,

ax/analysis/plotly/surface/slice.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -286,7 +286,12 @@ def _prepare_data(
286286
).sort_values(by=parameter_name)
287287

288288
if relativize:
289-
target_trial_index = none_throws(get_target_trial_index(experiment=experiment))
289+
target_trial_index = none_throws(
290+
get_target_trial_index(
291+
experiment=experiment,
292+
require_data_for_all_metrics=True,
293+
)
294+
)
290295
df = relativize_data(
291296
experiment=experiment,
292297
df=df,

ax/analysis/utils.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -171,7 +171,10 @@ def prepare_arm_data(
171171
# Compute the trial index of the target trial both to pass as a fixed feature
172172
# during prediction if using model predictions, and to relativize against the
173173
# status quo arm from the target trial if relativizing.
174-
target_trial_index = get_target_trial_index(experiment=experiment)
174+
target_trial_index = get_target_trial_index(
175+
experiment=experiment,
176+
require_data_for_all_metrics=True,
177+
)
175178
if use_model_predictions:
176179
if adapter is None:
177180
raise UserInputError(

ax/core/tests/test_utils.py

Lines changed: 90 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
from copy import deepcopy
10+
from datetime import datetime, timedelta
1011
from unittest.mock import patch
1112

1213
import numpy as np
@@ -43,6 +44,7 @@
4344
from ax.utils.common.testutils import TestCase
4445
from ax.utils.testing.core_stubs import (
4546
get_branin_data,
47+
get_branin_data_batch,
4648
get_branin_experiment,
4749
get_experiment,
4850
get_hierarchical_search_space_experiment,
@@ -178,6 +180,10 @@ def setUp(self) -> None:
178180
)
179181
],
180182
)
183+
self.batch_experiment = get_branin_experiment(with_completed_trial=False)
184+
self.batch_experiment.status_quo = Arm(
185+
name="status_quo", parameters={"x1": 0.0, "x2": 0.0}
186+
)
181187

182188
def test_get_missing_metrics_by_name(self) -> None:
183189
expected = {"a": {("0_1", 1)}, "b": {("0_2", 1)}}
@@ -772,6 +778,90 @@ def test_get_target_trial_index_non_batch(self) -> None:
772778
experiment.attach_data(get_branin_data(trials=[trial]))
773779
self.assertEqual(get_target_trial_index(experiment=experiment), trial.index)
774780

781+
def test_get_target_trial_index_stale_trial_filtering(self) -> None:
782+
trials = []
783+
for days_ago in [15, 5]: # old trial (stale), new trial (recent)
784+
trial = self.batch_experiment.new_batch_trial().add_arm(
785+
self.batch_experiment.status_quo
786+
)
787+
trial.mark_completed(unsafe=True)
788+
trial._time_completed = datetime.now() - timedelta(days=days_ago)
789+
self.batch_experiment.attach_data(get_branin_data_batch(batch=trial))
790+
trials.append(trial)
791+
792+
self.assertEqual(
793+
get_target_trial_index(experiment=self.batch_experiment),
794+
trials[1].index, # newer trial
795+
)
796+
797+
def test_get_target_trial_index_all_stale_fallback(self) -> None:
798+
trial = self.batch_experiment.new_batch_trial().add_arm(
799+
self.batch_experiment.status_quo
800+
)
801+
trial.mark_completed(unsafe=True)
802+
trial._time_completed = datetime.now() - timedelta(days=15) # stale
803+
self.batch_experiment.attach_data(get_branin_data_batch(batch=trial))
804+
805+
# fallback to stale trial over none
806+
self.assertEqual(
807+
get_target_trial_index(experiment=self.batch_experiment), trial.index
808+
)
809+
810+
def test_get_target_trial_index_longrun_to_shortrun_fallback(self) -> None:
811+
# long run without data
812+
long_run_trial = self.batch_experiment.new_batch_trial(
813+
trial_type=Keys.LONG_RUN
814+
).add_arm(self.batch_experiment.status_quo)
815+
long_run_trial.mark_running(no_runner_required=True)
816+
817+
# short run with data
818+
short_run_trial = self.batch_experiment.new_batch_trial().add_arm(
819+
self.batch_experiment.status_quo
820+
)
821+
short_run_trial.mark_running(no_runner_required=True)
822+
self.batch_experiment.attach_data(get_branin_data_batch(batch=short_run_trial))
823+
824+
# should fall back to short-run trial since long-run has no SQ data
825+
self.assertEqual(
826+
get_target_trial_index(experiment=self.batch_experiment),
827+
short_run_trial.index,
828+
)
829+
830+
# once long-run trial has data, should return long-run trial
831+
self.batch_experiment.attach_data(get_branin_data_batch(batch=long_run_trial))
832+
self.assertEqual(
833+
get_target_trial_index(experiment=self.batch_experiment),
834+
long_run_trial.index,
835+
)
836+
837+
def test_get_target_trial_index_opt_config_metric_filtering(self) -> None:
838+
# add tracking metric, opt config is already branin
839+
self.batch_experiment.add_tracking_metric(Metric(name="test_metric"))
840+
841+
# trial with opt config data only
842+
trial = (
843+
self.batch_experiment.new_batch_trial()
844+
.add_arm(self.batch_experiment.status_quo)
845+
.mark_running(no_runner_required=True)
846+
)
847+
self.batch_experiment.attach_data(get_branin_data_batch(batch=trial))
848+
849+
# default should pass because we'll have opt config data
850+
self.assertEqual(
851+
get_target_trial_index(
852+
experiment=self.batch_experiment, require_data_for_all_metrics=False
853+
),
854+
trial.index,
855+
)
856+
857+
# when require_data_for_all_metrics=True, should return None
858+
# because there are no trials with data for all metrics
859+
self.assertIsNone(
860+
get_target_trial_index(
861+
experiment=self.batch_experiment, require_data_for_all_metrics=True
862+
)
863+
)
864+
775865
def test_batch_trial_only_decorator(self) -> None:
776866
# Create a mock function to decorate
777867
def mock_func(trial: BatchTrial) -> None:

0 commit comments

Comments
 (0)