|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | 9 | from copy import deepcopy |
| 10 | +from datetime import datetime, timedelta |
10 | 11 | from unittest.mock import patch |
11 | 12 |
|
12 | 13 | import numpy as np |
|
43 | 44 | from ax.utils.common.testutils import TestCase |
44 | 45 | from ax.utils.testing.core_stubs import ( |
45 | 46 | get_branin_data, |
| 47 | + get_branin_data_batch, |
46 | 48 | get_branin_experiment, |
47 | 49 | get_experiment, |
48 | 50 | get_hierarchical_search_space_experiment, |
@@ -178,6 +180,10 @@ def setUp(self) -> None: |
178 | 180 | ) |
179 | 181 | ], |
180 | 182 | ) |
| 183 | + self.batch_experiment = get_branin_experiment(with_completed_trial=False) |
| 184 | + self.batch_experiment.status_quo = Arm( |
| 185 | + name="status_quo", parameters={"x1": 0.0, "x2": 0.0} |
| 186 | + ) |
181 | 187 |
|
182 | 188 | def test_get_missing_metrics_by_name(self) -> None: |
183 | 189 | expected = {"a": {("0_1", 1)}, "b": {("0_2", 1)}} |
@@ -772,6 +778,90 @@ def test_get_target_trial_index_non_batch(self) -> None: |
772 | 778 | experiment.attach_data(get_branin_data(trials=[trial])) |
773 | 779 | self.assertEqual(get_target_trial_index(experiment=experiment), trial.index) |
774 | 780 |
|
| 781 | + def test_get_target_trial_index_stale_trial_filtering(self) -> None: |
| 782 | + trials = [] |
| 783 | + for days_ago in [15, 5]: # old trial (stale), new trial (recent) |
| 784 | + trial = self.batch_experiment.new_batch_trial().add_arm( |
| 785 | + self.batch_experiment.status_quo |
| 786 | + ) |
| 787 | + trial.mark_completed(unsafe=True) |
| 788 | + trial._time_completed = datetime.now() - timedelta(days=days_ago) |
| 789 | + self.batch_experiment.attach_data(get_branin_data_batch(batch=trial)) |
| 790 | + trials.append(trial) |
| 791 | + |
| 792 | + self.assertEqual( |
| 793 | + get_target_trial_index(experiment=self.batch_experiment), |
| 794 | + trials[1].index, # newer trial |
| 795 | + ) |
| 796 | + |
| 797 | + def test_get_target_trial_index_all_stale_fallback(self) -> None: |
| 798 | + trial = self.batch_experiment.new_batch_trial().add_arm( |
| 799 | + self.batch_experiment.status_quo |
| 800 | + ) |
| 801 | + trial.mark_completed(unsafe=True) |
| 802 | + trial._time_completed = datetime.now() - timedelta(days=15) # stale |
| 803 | + self.batch_experiment.attach_data(get_branin_data_batch(batch=trial)) |
| 804 | + |
| 805 | +        # falls back to the stale trial rather than returning None
| 806 | + self.assertEqual( |
| 807 | + get_target_trial_index(experiment=self.batch_experiment), trial.index |
| 808 | + ) |
| 809 | + |
| 810 | + def test_get_target_trial_index_longrun_to_shortrun_fallback(self) -> None: |
| 811 | + # long run without data |
| 812 | + long_run_trial = self.batch_experiment.new_batch_trial( |
| 813 | + trial_type=Keys.LONG_RUN |
| 814 | + ).add_arm(self.batch_experiment.status_quo) |
| 815 | + long_run_trial.mark_running(no_runner_required=True) |
| 816 | + |
| 817 | + # short run with data |
| 818 | + short_run_trial = self.batch_experiment.new_batch_trial().add_arm( |
| 819 | + self.batch_experiment.status_quo |
| 820 | + ) |
| 821 | + short_run_trial.mark_running(no_runner_required=True) |
| 822 | + self.batch_experiment.attach_data(get_branin_data_batch(batch=short_run_trial)) |
| 823 | + |
| 824 | +        # should fall back to the short-run trial since the long-run trial has no SQ data
| 825 | + self.assertEqual( |
| 826 | + get_target_trial_index(experiment=self.batch_experiment), |
| 827 | + short_run_trial.index, |
| 828 | + ) |
| 829 | + |
| 830 | + # once long-run trial has data, should return long-run trial |
| 831 | + self.batch_experiment.attach_data(get_branin_data_batch(batch=long_run_trial)) |
| 832 | + self.assertEqual( |
| 833 | + get_target_trial_index(experiment=self.batch_experiment), |
| 834 | + long_run_trial.index, |
| 835 | + ) |
| 836 | + |
| 837 | + def test_get_target_trial_index_opt_config_metric_filtering(self) -> None: |
| 838 | +        # add a tracking metric; the opt config already contains the branin metric
| 839 | + self.batch_experiment.add_tracking_metric(Metric(name="test_metric")) |
| 840 | + |
| 841 | + # trial with opt config data only |
| 842 | + trial = ( |
| 843 | + self.batch_experiment.new_batch_trial() |
| 844 | + .add_arm(self.batch_experiment.status_quo) |
| 845 | + .mark_running(no_runner_required=True) |
| 846 | + ) |
| 847 | + self.batch_experiment.attach_data(get_branin_data_batch(batch=trial)) |
| 848 | + |
| 849 | +        # the default (require_data_for_all_metrics=False) passes because the trial has opt config data
| 850 | + self.assertEqual( |
| 851 | + get_target_trial_index( |
| 852 | + experiment=self.batch_experiment, require_data_for_all_metrics=False |
| 853 | + ), |
| 854 | + trial.index, |
| 855 | + ) |
| 856 | + |
| 857 | + # when require_data_for_all_metrics=True, should return None |
| 858 | + # because there are no trials with data for all metrics |
| 859 | + self.assertIsNone( |
| 860 | + get_target_trial_index( |
| 861 | + experiment=self.batch_experiment, require_data_for_all_metrics=True |
| 862 | + ) |
| 863 | + ) |
| 864 | + |
775 | 865 | def test_batch_trial_only_decorator(self) -> None: |
776 | 866 | # Create a mock function to decorate |
777 | 867 | def mock_func(trial: BatchTrial) -> None: |
|