Skip to content

Commit dbf7bef

Browse files
andycylmetafacebook-github-bot
authored andcommitted
Sort experiment data by trial_index, arm_name, and metrics (facebook#3885)
Summary: Pull Request resolved: facebook#3885 Arm order in metric results were displayed per the data order. This led to a random ordering of arms in the UI {F1977891605} To fix this, we want to sort the arm ordering in data. Data will be sorted by trail_index, then arm_name. Arm name will be sorted as 'custom_name' < '0_1' < '0_2' < '0_11' < '0_100' Reviewed By: lena-kashtelyan, saitcakmak Differential Revision: D74409276 fbshipit-source-id: 51d4de4e50f20ac009ed29ec58d4a417be2857a1
1 parent 2d433d2 commit dbf7bef

File tree

3 files changed

+213
-2
lines changed

3 files changed

+213
-2
lines changed

ax/core/base_trial.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from datetime import datetime, timedelta
1515
from typing import Any, TYPE_CHECKING
1616

17+
import pandas as pd
18+
1719
from ax.core.arm import Arm
1820
from ax.core.data import Data
1921
from ax.core.formatting_utils import data_and_evaluations_from_raw_data
@@ -386,9 +388,13 @@ def fetch_data(self, metrics: list[Metric] | None = None, **kwargs: Any) -> Data
386388
MapMetric if self.experiment.default_data_constructor == MapData else Metric
387389
)
388390

389-
return base_metric_cls._unwrap_trial_data_multi(
391+
data = base_metric_cls._unwrap_trial_data_multi(
390392
results=self.fetch_data_results(metrics=metrics, **kwargs)
391393
)
394+
if not isinstance(data, MapData):
395+
data._df = sort_by_trial_index_and_arm_name(data._df)
396+
397+
return data
392398

393399
def lookup_data(self) -> Data:
394400
"""Lookup cached data on experiment for this trial.
@@ -831,3 +837,58 @@ def _update_trial_attrs_on_clone(
831837
new_trial.mark_failed(reason=self.failed_reason)
832838
return
833839
new_trial.mark_as(self.status, unsafe=True)
840+
841+
842+
def sort_by_trial_index_and_arm_name(df: pd.DataFrame) -> pd.DataFrame:
843+
"""
844+
Sorts the dataframe by trial index and arm name. The arm names with default patterns
845+
(e.g. `0_1`, `3_11`) are sorted by trial index part (before underscore) and arm
846+
number part (after underscore) within trial index. The arm names with non-default
847+
patterns (e.g. `status_quo`, `control`, `capped_param_1`) are sorted alphabetically
848+
and will be on the top of the sorted dataframe.
849+
850+
Args:
851+
df: The DataFrame to sort.
852+
853+
Returns:
854+
The sorted DataFrame.
855+
"""
856+
857+
# Create new columns for sorting the default arm names
858+
df["is_default"] = pd.notna(df["arm_name"]) & df["arm_name"].str.count(
859+
pat=r"^\d+_\d+$"
860+
)
861+
862+
df["trial_index_part"] = float("NaN")
863+
df["arm_name_part"] = float("NaN")
864+
865+
split_arm_name = df.loc[df["is_default"], "arm_name"].str.split("_")
866+
df.loc[df["is_default"], "trial_index_part"] = split_arm_name.str.get(0).astype(int)
867+
df.loc[df["is_default"], "arm_name_part"] = split_arm_name.str.get(1).astype(int)
868+
869+
# Sort the DataFrame by the new columns (trial_index_part and arm_number_part)
870+
# for default arm names
871+
df = (
872+
df.sort_values(
873+
by=[
874+
"trial_index",
875+
"is_default",
876+
"trial_index_part",
877+
"arm_name_part",
878+
"arm_name",
879+
],
880+
inplace=False,
881+
).reset_index(drop=True)
882+
if not df.empty
883+
else df
884+
)
885+
886+
# Drop the temporary 'trial_index_part' and 'arm_number_part' columns
887+
df.drop(
888+
columns=["trial_index_part", "arm_name_part", "is_default"],
889+
# Ignore errors that occur when dropping columns that do not exist in the
890+
# dataframe.
891+
errors="ignore",
892+
inplace=True,
893+
)
894+
return df

ax/core/experiment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pandas as pd
2222
from ax.core.arm import Arm
2323
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
24-
from ax.core.base_trial import BaseTrial
24+
from ax.core.base_trial import BaseTrial, sort_by_trial_index_and_arm_name
2525
from ax.core.batch_trial import BatchTrial, LifecycleStage
2626
from ax.core.data import Data
2727
from ax.core.formatting_utils import DATA_TYPE_LOOKUP, DataType
@@ -713,6 +713,7 @@ def _lookup_or_fetch_trials_results(
713713
trials=trials,
714714
**kwargs,
715715
)
716+
716717
contains_new_data = contains_new_data or new_results_contains_new_data
717718

718719
# Merge in results
@@ -820,6 +821,8 @@ def attach_data(
820821
)
821822
cur_time_millis = current_timestamp_in_millis()
822823
for trial_index, trial_df in data.true_df.groupby(data.true_df["trial_index"]):
824+
if not isinstance(data, MapData):
825+
trial_df = sort_by_trial_index_and_arm_name(df=trial_df)
823826
# Overwrite `df` so that `data` only has current trial data.
824827
data_init_args["df"] = trial_df
825828
current_trial_data = (

ax/core/tests/test_experiment.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
1717
from ax.core.base_trial import BaseTrial, TrialStatus
1818
from ax.core.data import Data
19+
from ax.core.experiment import sort_by_trial_index_and_arm_name
1920
from ax.core.map_data import MapData
2021
from ax.core.map_metric import MapMetric
2122
from ax.core.metric import Metric
@@ -65,6 +66,7 @@
6566
get_test_map_data_experiment,
6667
)
6768
from ax.utils.testing.mock import mock_botorch_optimize
69+
from pandas.testing import assert_frame_equal
6870
from pyre_extensions import assert_is_instance
6971

7072
DUMMY_RUN_METADATA_KEY = "test_run_metadata_key"
@@ -697,6 +699,81 @@ def test_FetchTrialsData(self) -> None:
697699
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
698700
)
699701

702+
def test_attach_and_sort_data(self) -> None:
703+
n = 4
704+
exp = self._setupBraninExperiment(n)
705+
batch = exp.trials[0]
706+
batch.mark_completed()
707+
self.assertEqual(exp.completed_trials, [batch])
708+
709+
# test sorting data
710+
unsorted_df = pd.DataFrame(
711+
{
712+
"arm_name": [
713+
"0_0",
714+
"0_2",
715+
"0_11",
716+
"0_1",
717+
"status_quo",
718+
"1_0",
719+
"1_1",
720+
"1_2",
721+
"1_13",
722+
],
723+
"metric_name": ["b"] * 9,
724+
"mean": list(range(1, 10)),
725+
"sem": [0.1 + i * 0.05 for i in range(9)],
726+
"trial_index": [0, 0, 0, 0, 0, 1, 1, 1, 1],
727+
}
728+
)
729+
730+
sorted_dfs = []
731+
sorted_dfs.append(
732+
pd.DataFrame(
733+
{
734+
"trial_index": [0] * 5,
735+
"arm_name": [
736+
"status_quo",
737+
"0_0",
738+
"0_1",
739+
"0_2",
740+
"0_11",
741+
],
742+
"metric_name": ["b"] * 5,
743+
"mean": [5.0, 1.0, 4.0, 2.0, 3.0],
744+
"sem": [0.3, 0.1, 0.25, 0.15, 0.2],
745+
}
746+
)
747+
)
748+
749+
sorted_dfs.append(
750+
pd.DataFrame(
751+
{
752+
"trial_index": [1] * 4,
753+
"arm_name": [
754+
"1_0",
755+
"1_1",
756+
"1_2",
757+
"1_13",
758+
],
759+
"metric_name": ["b"] * 4,
760+
"mean": [6.0, 7.0, 8.0, 9.0],
761+
"sem": [0.35, 0.4, 0.45, 0.5],
762+
}
763+
)
764+
)
765+
766+
exp.attach_data(
767+
Data(
768+
df=unsorted_df,
769+
)
770+
)
771+
for trial_index in [0, 1]:
772+
assert_frame_equal(
773+
list(exp.data_by_trial[trial_index].values())[0].df,
774+
sorted_dfs[trial_index],
775+
)
776+
700777
def test_immutable_search_space_and_opt_config(self) -> None:
701778
mutable_exp = self._setupBraninExperiment(n=5)
702779
self.assertFalse(mutable_exp.immutable_search_space_and_opt_config)
@@ -1750,3 +1827,73 @@ def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature
17501827
experiment._name_and_store_arm_if_not_exists(
17511828
arm=arm_2, proposed_name="different proposed name"
17521829
)
1830+
1831+
def test_sorting_data_by_trial_index_and_arm_name(self) -> None:
1832+
# test sorting data
1833+
unsorted_df = pd.DataFrame(
1834+
{
1835+
"trial_index": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
1836+
"arm_name": [
1837+
"0_0",
1838+
"0_2",
1839+
"custom_arm_1",
1840+
"0_11",
1841+
"status_quo",
1842+
"0_1",
1843+
"1_0",
1844+
"custom_arm_2",
1845+
"1_1",
1846+
"status_quo",
1847+
"1_2",
1848+
"1_3",
1849+
],
1850+
"metric_name": ["b"] * 12,
1851+
"mean": [float(x) for x in range(1, 13)],
1852+
"sem": [0.1 + i * 0.05 for i in range(12)],
1853+
}
1854+
)
1855+
1856+
expected_sorted_df = pd.DataFrame(
1857+
{
1858+
"trial_index": [0] * 6 + [1] * 6,
1859+
"arm_name": [
1860+
"custom_arm_1",
1861+
"status_quo",
1862+
"0_0",
1863+
"0_1",
1864+
"0_2",
1865+
"0_11",
1866+
"custom_arm_2",
1867+
"status_quo",
1868+
"1_0",
1869+
"1_1",
1870+
"1_2",
1871+
"1_3",
1872+
],
1873+
"metric_name": ["b"] * 12,
1874+
"mean": [3.0, 5.0, 1.0, 6.0, 2.0, 4.0, 8.0, 10.0, 7.0, 9.0, 11.0, 12.0],
1875+
"sem": [
1876+
0.2,
1877+
0.3,
1878+
0.1,
1879+
0.35,
1880+
0.15,
1881+
0.25,
1882+
0.45,
1883+
0.55,
1884+
0.4,
1885+
0.5,
1886+
0.6,
1887+
0.65,
1888+
],
1889+
}
1890+
)
1891+
1892+
sorted_df = sort_by_trial_index_and_arm_name(
1893+
df=unsorted_df,
1894+
)
1895+
1896+
assert_frame_equal(
1897+
sorted_df,
1898+
expected_sorted_df,
1899+
)

0 commit comments

Comments
 (0)