|
16 | 16 | from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
|
17 | 17 | from ax.core.base_trial import BaseTrial, TrialStatus
|
18 | 18 | from ax.core.data import Data
|
| 19 | +from ax.core.experiment import sort_by_trial_index_and_arm_name |
19 | 20 | from ax.core.map_data import MapData
|
20 | 21 | from ax.core.map_metric import MapMetric
|
21 | 22 | from ax.core.metric import Metric
|
|
65 | 66 | get_test_map_data_experiment,
|
66 | 67 | )
|
67 | 68 | from ax.utils.testing.mock import mock_botorch_optimize
|
| 69 | +from pandas.testing import assert_frame_equal |
68 | 70 | from pyre_extensions import assert_is_instance
|
69 | 71 |
|
70 | 72 | DUMMY_RUN_METADATA_KEY = "test_run_metadata_key"
|
@@ -697,6 +699,81 @@ def test_FetchTrialsData(self) -> None:
|
697 | 699 | set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
|
698 | 700 | )
|
699 | 701 |
|
| 702 | + def test_attach_and_sort_data(self) -> None: |
| 703 | + n = 4 |
| 704 | + exp = self._setupBraninExperiment(n) |
| 705 | + batch = exp.trials[0] |
| 706 | + batch.mark_completed() |
| 707 | + self.assertEqual(exp.completed_trials, [batch]) |
| 708 | + |
| 709 | + # test sorting data |
| 710 | + unsorted_df = pd.DataFrame( |
| 711 | + { |
| 712 | + "arm_name": [ |
| 713 | + "0_0", |
| 714 | + "0_2", |
| 715 | + "0_11", |
| 716 | + "0_1", |
| 717 | + "status_quo", |
| 718 | + "1_0", |
| 719 | + "1_1", |
| 720 | + "1_2", |
| 721 | + "1_13", |
| 722 | + ], |
| 723 | + "metric_name": ["b"] * 9, |
| 724 | + "mean": list(range(1, 10)), |
| 725 | + "sem": [0.1 + i * 0.05 for i in range(9)], |
| 726 | + "trial_index": [0, 0, 0, 0, 0, 1, 1, 1, 1], |
| 727 | + } |
| 728 | + ) |
| 729 | + |
| 730 | + sorted_dfs = [] |
| 731 | + sorted_dfs.append( |
| 732 | + pd.DataFrame( |
| 733 | + { |
| 734 | + "trial_index": [0] * 5, |
| 735 | + "arm_name": [ |
| 736 | + "status_quo", |
| 737 | + "0_0", |
| 738 | + "0_1", |
| 739 | + "0_2", |
| 740 | + "0_11", |
| 741 | + ], |
| 742 | + "metric_name": ["b"] * 5, |
| 743 | + "mean": [5.0, 1.0, 4.0, 2.0, 3.0], |
| 744 | + "sem": [0.3, 0.1, 0.25, 0.15, 0.2], |
| 745 | + } |
| 746 | + ) |
| 747 | + ) |
| 748 | + |
| 749 | + sorted_dfs.append( |
| 750 | + pd.DataFrame( |
| 751 | + { |
| 752 | + "trial_index": [1] * 4, |
| 753 | + "arm_name": [ |
| 754 | + "1_0", |
| 755 | + "1_1", |
| 756 | + "1_2", |
| 757 | + "1_13", |
| 758 | + ], |
| 759 | + "metric_name": ["b"] * 4, |
| 760 | + "mean": [6.0, 7.0, 8.0, 9.0], |
| 761 | + "sem": [0.35, 0.4, 0.45, 0.5], |
| 762 | + } |
| 763 | + ) |
| 764 | + ) |
| 765 | + |
| 766 | + exp.attach_data( |
| 767 | + Data( |
| 768 | + df=unsorted_df, |
| 769 | + ) |
| 770 | + ) |
| 771 | + for trial_index in [0, 1]: |
| 772 | + assert_frame_equal( |
| 773 | + list(exp.data_by_trial[trial_index].values())[0].df, |
| 774 | + sorted_dfs[trial_index], |
| 775 | + ) |
| 776 | + |
700 | 777 | def test_immutable_search_space_and_opt_config(self) -> None:
|
701 | 778 | mutable_exp = self._setupBraninExperiment(n=5)
|
702 | 779 | self.assertFalse(mutable_exp.immutable_search_space_and_opt_config)
|
@@ -1750,3 +1827,73 @@ def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature
|
1750 | 1827 | experiment._name_and_store_arm_if_not_exists(
|
1751 | 1828 | arm=arm_2, proposed_name="different proposed name"
|
1752 | 1829 | )
|
| 1830 | + |
| 1831 | + def test_sorting_data_by_trial_index_and_arm_name(self) -> None: |
| 1832 | + # test sorting data |
| 1833 | + unsorted_df = pd.DataFrame( |
| 1834 | + { |
| 1835 | + "trial_index": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], |
| 1836 | + "arm_name": [ |
| 1837 | + "0_0", |
| 1838 | + "0_2", |
| 1839 | + "custom_arm_1", |
| 1840 | + "0_11", |
| 1841 | + "status_quo", |
| 1842 | + "0_1", |
| 1843 | + "1_0", |
| 1844 | + "custom_arm_2", |
| 1845 | + "1_1", |
| 1846 | + "status_quo", |
| 1847 | + "1_2", |
| 1848 | + "1_3", |
| 1849 | + ], |
| 1850 | + "metric_name": ["b"] * 12, |
| 1851 | + "mean": [float(x) for x in range(1, 13)], |
| 1852 | + "sem": [0.1 + i * 0.05 for i in range(12)], |
| 1853 | + } |
| 1854 | + ) |
| 1855 | + |
| 1856 | + expected_sorted_df = pd.DataFrame( |
| 1857 | + { |
| 1858 | + "trial_index": [0] * 6 + [1] * 6, |
| 1859 | + "arm_name": [ |
| 1860 | + "custom_arm_1", |
| 1861 | + "status_quo", |
| 1862 | + "0_0", |
| 1863 | + "0_1", |
| 1864 | + "0_2", |
| 1865 | + "0_11", |
| 1866 | + "custom_arm_2", |
| 1867 | + "status_quo", |
| 1868 | + "1_0", |
| 1869 | + "1_1", |
| 1870 | + "1_2", |
| 1871 | + "1_3", |
| 1872 | + ], |
| 1873 | + "metric_name": ["b"] * 12, |
| 1874 | + "mean": [3.0, 5.0, 1.0, 6.0, 2.0, 4.0, 8.0, 10.0, 7.0, 9.0, 11.0, 12.0], |
| 1875 | + "sem": [ |
| 1876 | + 0.2, |
| 1877 | + 0.3, |
| 1878 | + 0.1, |
| 1879 | + 0.35, |
| 1880 | + 0.15, |
| 1881 | + 0.25, |
| 1882 | + 0.45, |
| 1883 | + 0.55, |
| 1884 | + 0.4, |
| 1885 | + 0.5, |
| 1886 | + 0.6, |
| 1887 | + 0.65, |
| 1888 | + ], |
| 1889 | + } |
| 1890 | + ) |
| 1891 | + |
| 1892 | + sorted_df = sort_by_trial_index_and_arm_name( |
| 1893 | + df=unsorted_df, |
| 1894 | + ) |
| 1895 | + |
| 1896 | + assert_frame_equal( |
| 1897 | + sorted_df, |
| 1898 | + expected_sorted_df, |
| 1899 | + ) |
0 commit comments