Skip to content

Commit f1daf09

Browse files
saitcakmak authored and facebook-github-bot committed
Implement OneHot.transform_experiment_data (facebook#3886)
Summary: Pull Request resolved: facebook#3886 As titled. Supports transforming `ExperimentData` with `OneHot` transform. Background: As part of the larger refactor, we will be using `ExperimentData` in place of `list[Observation]` within the `Adapter`. - The transforms will be initialized using `ExperimentData`. The `observations` input to the constructors may be deprecated once the use cases are updated. - The training data for `Adapter` will be represented with `ExperimentData` and will be transformed using `transform_experiment_data`. - For misc input / output to various `Adapter` and other methods, the `Observation / ObservationFeatures / ObservationData` objects will remain. To support these, we will retain the existing transform methods that service these objects. - Since `ExperimentData` is not planned to be used as an output of user facing methods, we do not need to untransform it. We are not planning to implement `untransform_experiment_data`. Reviewed By: esantorella Differential Revision: D76074577 fbshipit-source-id: 07f17b5f91a7e1ce5c6d9eb91dd034bec17d6f89
1 parent 9983105 commit f1daf09

File tree

2 files changed

+97
-10
lines changed

2 files changed

+97
-10
lines changed

ax/adapter/transforms/one_hot.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Optional, TYPE_CHECKING
1010

1111
import numpy as np
12+
import pandas as pd
1213
from ax.adapter.data_utils import ExperimentData
1314
from ax.adapter.transforms.base import Transform
1415
from ax.adapter.transforms.rounding import randomized_onehot_round, strict_onehot_round
@@ -207,3 +208,35 @@ def untransform_observation_features(
207208
)
208209
obsf.parameters[p_name] = val
209210
return observation_features
211+
212+
def transform_experiment_data(
    self, experiment_data: ExperimentData
) -> ExperimentData:
    """One-hot encode the encoded parameter columns of ``arm_data``.

    For every ``(parameter name, values)`` pair in ``self.encoded_values``
    the parameter column is first recoded to the integer position of each
    value within ``values``. Then:

    * with exactly two values, the integer column is kept and only renamed
      to ``<name><OH_PARAM_INFIX>``;
    * otherwise the column is replaced by float indicator columns named
      ``<name><OH_PARAM_INFIX>_<i>``, one per position ``i``. Positions
      that never occur in the data still get an all-zero column.

    Args:
        experiment_data: The data whose ``arm_data`` should be encoded.

    Returns:
        An ``ExperimentData`` built from the encoded ``arm_data`` and the
        input's ``observation_data``, which is passed through as-is.
    """
    frame = experiment_data.arm_data
    for p_name, values in self.encoded_values.items():
        encoded_name = p_name + OH_PARAM_INFIX

        # Recode values as 0, 1, 2, ... so that the generated column
        # suffixes are the value positions rather than the raw values.
        position_of = {value: pos for pos, value in enumerate(values)}
        frame = frame.replace(to_replace={p_name: position_of})
        frame = frame.astype({p_name: int})

        if len(values) == 2:
            # Special case: the 0/1 column already is the encoding, so it
            # only needs the encoded name.
            frame = frame.rename(columns={p_name: encoded_name})
            continue

        frame = pd.get_dummies(
            frame,
            columns=[p_name],
            prefix=encoded_name,
            # Could be int, but using float to match the parameter type.
            dtype=float,
        )
        # get_dummies only emits columns for positions present in the data;
        # add the remaining expected columns filled with zeros.
        for pos in range(len(values)):
            indicator = f"{encoded_name}_{pos}"
            if indicator not in frame:
                frame[indicator] = 0.0

    return ExperimentData(
        arm_data=frame, observation_data=experiment_data.observation_data
    )

ax/adapter/transforms/tests/test_one_hot_transform.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,22 @@
88

99
from copy import deepcopy
1010

11-
from ax.adapter.transforms.one_hot import OH_PARAM_INFIX, OneHot
11+
from ax.adapter.base import DataLoaderConfig
12+
13+
from ax.adapter.data_utils import extract_experiment_data
1214

15+
from ax.adapter.transforms.one_hot import OH_PARAM_INFIX, OneHot
1316
from ax.core.observation import ObservationFeatures
1417
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
1518
from ax.core.parameter_constraint import ParameterConstraint
1619
from ax.core.search_space import RobustSearchSpace, SearchSpace
1720
from ax.utils.common.testutils import TestCase
18-
from ax.utils.testing.core_stubs import get_robust_search_space
21+
from ax.utils.testing.core_stubs import (
22+
get_experiment_with_observations,
23+
get_robust_search_space,
24+
)
25+
from pandas import DataFrame
26+
from pandas.testing import assert_frame_equal
1927

2028

2129
class OneHotTransformTest(TestCase):
@@ -34,9 +42,7 @@ def setUp(self) -> None:
3442
"b", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
3543
),
3644
ChoiceParameter(
37-
"c",
38-
parameter_type=ParameterType.BOOL,
39-
values=[True, False],
45+
"c", parameter_type=ParameterType.BOOL, values=[True, False]
4046
),
4147
ChoiceParameter(
4248
"d",
@@ -49,13 +55,9 @@ def setUp(self) -> None:
4955
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5)
5056
],
5157
)
52-
self.t = OneHot(
53-
search_space=self.search_space,
54-
observations=[],
55-
)
58+
self.t = OneHot(search_space=self.search_space)
5659
self.t2 = OneHot(
5760
search_space=self.search_space,
58-
observations=[],
5961
config={"rounding": "randomized"},
6062
)
6163

@@ -255,3 +257,55 @@ def test_heterogeneous_search_space(self) -> None:
255257
]
256258
untf_obs = self.t.untransform_observation_features(obs_ft)
257259
self.assertFalse(any(obs.parameters.get("b") == "b" for obs in untf_obs))
260+
261+
def test_transform_experiment_data(self) -> None:
    """Check ``OneHot.transform_experiment_data`` on a two-arm experiment.

    Verifies that only parameter "b" is expanded into indicator columns,
    that all other arm columns and the observation data are unchanged,
    and that the indicator columns hold the expected values.
    """
    experiment = get_experiment_with_observations(
        observations=[[1.0], [2.0]],
        search_space=self.search_space,
        parameterizations=[
            {"x": 2.2, "a": 2, "b": "b", "c": False, "d": 10.0},
            {"x": 1.2, "a": 2, "b": "a", "c": False, "d": 100.0},
        ],
    )
    original = extract_experiment_data(
        experiment=experiment, data_loader_config=DataLoaderConfig()
    )
    # Transform a deep copy so `original` stays available as the reference.
    transformed = self.t.transform_experiment_data(
        experiment_data=deepcopy(original)
    )
    transformed_arms = transformed.arm_data

    # Only "b" should be replaced, by one indicator column per value.
    untouched = ["x", "a", "c", "d", "metadata"]
    one_hot = ["b" + OH_PARAM_INFIX + suffix for suffix in ("_0", "_1", "_2")]
    self.assertEqual(set(transformed_arms), set(untouched + one_hot))

    # Columns that are not encoded are the same as before.
    assert_frame_equal(transformed_arms[untouched], original.arm_data[untouched])
    # Observation data is unchanged.
    assert_frame_equal(transformed.observation_data, original.observation_data)

    # First arm has b="b" (position 1), second arm has b="a" (position 0).
    expected = DataFrame(
        data=[
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
        ],
        index=transformed_arms.index,
        columns=one_hot,
    )
    assert_frame_equal(transformed_arms[one_hot], expected)

0 commit comments

Comments
 (0)