Skip to content

Commit a8e35ac

Browse files
saitcakmak authored and facebook-github-bot committed
Implement UnitX.transform_experiment_data (facebook#3888)
Summary:
Pull Request resolved: facebook#3888

As titled. Supports transforming `ExperimentData` with the `UnitX` transform.

Background: As part of the larger refactor, we will be using `ExperimentData` in place of `list[Observation]` within the `Adapter`.
- The transforms will be initialized using `ExperimentData`. The `observations` input to the constructors may be deprecated once the use cases are updated.
- The training data for `Adapter` will be represented with `ExperimentData` and will be transformed using `transform_experiment_data`.
- For misc input / output to various `Adapter` and other methods, the `Observation / ObservationFeatures / ObservationData` objects will remain. To support these, we will retain the existing transform methods that service these objects.
- Since `ExperimentData` is not planned to be used as an output of user-facing methods, we do not need to untransform it. We are not planning to implement `untransform_experiment_data`.

Reviewed By: esantorella

Differential Revision: D76085733

fbshipit-source-id: 40a07675cfc4252f9d5699423e6e2bfea5295378
1 parent 9d151bf commit a8e35ac

File tree

2 files changed

+113
-131
lines changed

2 files changed

+113
-131
lines changed

ax/adapter/transforms/tests/test_unit_x_transform.py

Lines changed: 90 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,27 @@
99
from copy import deepcopy
1010

1111
import numpy as np
12+
from ax.adapter.base import DataLoaderConfig
13+
from ax.adapter.data_utils import extract_experiment_data
1214
from ax.adapter.transforms.unit_x import UnitX
1315
from ax.core.observation import ObservationFeatures
1416
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
1517
from ax.core.parameter_constraint import ParameterConstraint
1618
from ax.core.search_space import RobustSearchSpace, SearchSpace
1719
from ax.exceptions.core import UnsupportedError, UserInputError
1820
from ax.utils.common.testutils import TestCase
19-
from ax.utils.testing.core_stubs import get_robust_search_space
21+
from ax.utils.testing.core_stubs import (
22+
get_experiment_with_observations,
23+
get_robust_search_space,
24+
)
25+
from pandas import DataFrame
26+
from pandas.testing import assert_frame_equal
2027
from pyre_extensions import assert_is_instance
2128

2229

2330
class UnitXTransformTest(TestCase):
24-
transform_class = UnitX
25-
# pyre-fixme[4]: Attribute must be annotated.
26-
expected_c_dicts = [{"x": -1.0, "y": 1.0}, {"x": -1.0, "a": 1.0}]
27-
expected_c_bounds = [0.0, 1.0]
28-
2931
def setUp(self) -> None:
3032
super().setUp()
31-
self.target_lb = self.transform_class.target_lb
32-
self.target_range = self.transform_class.target_range
33-
self.target_ub = self.target_lb + self.target_range
3433
self.search_space = SearchSpace(
3534
parameters=[
3635
RangeParameter(
@@ -56,10 +55,7 @@ def setUp(self) -> None:
5655
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5),
5756
],
5857
)
59-
self.t = self.transform_class(
60-
search_space=self.search_space,
61-
observations=[],
62-
)
58+
self.t = UnitX(search_space=self.search_space)
6359
self.search_space_with_target = SearchSpace(
6460
parameters=[
6561
RangeParameter(
@@ -86,13 +82,7 @@ def test_TransformObservationFeatures(self) -> None:
8682
obs_ft2,
8783
[
8884
ObservationFeatures(
89-
parameters={
90-
"x": self.target_lb + self.target_range / 2.0,
91-
"y": 1.0,
92-
"z": 2,
93-
"a": 2,
94-
"b": "b",
95-
}
85+
parameters={"x": 0.5, "y": 1.0, "z": 2, "a": 2, "b": "b"}
9686
)
9787
],
9888
)
@@ -103,7 +93,7 @@ def test_TransformObservationFeatures(self) -> None:
10393
obs_ft3 = self.t.transform_observation_features(obs_ft3)
10494
self.assertEqual(
10595
obs_ft3[0],
106-
ObservationFeatures(parameters={"x": self.target_ub, "z": 2}),
96+
ObservationFeatures(parameters={"x": 1.0, "z": 2}),
10797
)
10898
obs_ft5 = self.t.transform_observation_features([ObservationFeatures({})])
10999
self.assertEqual(obs_ft5[0], ObservationFeatures({}))
@@ -114,31 +104,35 @@ def test_TransformSearchSpace(self) -> None:
114104

115105
# Parameters transformed
116106
true_bounds = {
117-
"x": (self.target_lb, 1.0),
118-
"y": (self.target_lb, 1.0),
107+
"x": (0.0, 1.0),
108+
"y": (0.0, 1.0),
119109
"z": (1.0, 2.0),
120110
"a": (1.0, 2.0),
121111
}
122112
for p_name, (l, u) in true_bounds.items():
123-
self.assertEqual(ss2.parameters[p_name].lower, l)
124-
self.assertEqual(ss2.parameters[p_name].upper, u)
125-
self.assertEqual(ss2.parameters["b"].values, ["a", "b", "c"])
113+
self.assertEqual(
114+
assert_is_instance(ss2.parameters[p_name], RangeParameter).lower, l
115+
)
116+
self.assertEqual(
117+
assert_is_instance(ss2.parameters[p_name], RangeParameter).upper, u
118+
)
119+
self.assertEqual(
120+
assert_is_instance(ss2.parameters["b"], ChoiceParameter).values,
121+
["a", "b", "c"],
122+
)
126123
self.assertEqual(len(ss2.parameters), 5)
127124
# Constraints transformed
128125
self.assertEqual(
129-
ss2.parameter_constraints[0].constraint_dict, self.expected_c_dicts[0]
126+
ss2.parameter_constraints[0].constraint_dict, {"x": -1.0, "y": 1.0}
130127
)
131-
self.assertEqual(ss2.parameter_constraints[0].bound, self.expected_c_bounds[0])
128+
self.assertEqual(ss2.parameter_constraints[0].bound, 0.0)
132129
self.assertEqual(
133-
ss2.parameter_constraints[1].constraint_dict, self.expected_c_dicts[1]
130+
ss2.parameter_constraints[1].constraint_dict, {"x": -1.0, "a": 1.0}
134131
)
135-
self.assertEqual(ss2.parameter_constraints[1].bound, self.expected_c_bounds[1])
132+
self.assertEqual(ss2.parameter_constraints[1].bound, 1.0)
136133

137134
# Test transform of target value
138-
t = self.transform_class(
139-
search_space=self.search_space_with_target,
140-
observations=[],
141-
)
135+
t = UnitX(search_space=self.search_space_with_target)
142136
t.transform_search_space(self.search_space_with_target)
143137
self.assertEqual(
144138
self.search_space_with_target.parameters["x"].target_value, 1.0
@@ -175,14 +169,8 @@ def test_TransformNewSearchSpace(self) -> None:
175169
self.t.transform_search_space(new_ss)
176170
# Parameters transformed
177171
true_bounds = {
178-
"x": [
179-
0.25 * self.target_range + self.target_lb,
180-
0.5 * self.target_range + self.target_lb,
181-
],
182-
"y": [
183-
0.25 * self.target_range + self.target_lb,
184-
1.0 * self.target_range + self.target_lb,
185-
],
172+
"x": [0.25, 0.5],
173+
"y": [0.25, 1.0],
186174
"z": [1.0, 1.5],
187175
"a": [0, 2],
188176
}
@@ -197,23 +185,16 @@ def test_TransformNewSearchSpace(self) -> None:
197185
self.assertEqual(len(new_ss.parameters), 5)
198186
# # Constraints transformed
199187
self.assertEqual(
200-
new_ss.parameter_constraints[0].constraint_dict, self.expected_c_dicts[0]
201-
)
202-
self.assertEqual(
203-
new_ss.parameter_constraints[0].bound, self.expected_c_bounds[0]
204-
)
205-
self.assertEqual(
206-
new_ss.parameter_constraints[1].constraint_dict, self.expected_c_dicts[1]
188+
new_ss.parameter_constraints[0].constraint_dict, {"x": -1.0, "y": 1.0}
207189
)
190+
self.assertEqual(new_ss.parameter_constraints[0].bound, 0.0)
208191
self.assertEqual(
209-
new_ss.parameter_constraints[1].bound, self.expected_c_bounds[1]
192+
new_ss.parameter_constraints[1].constraint_dict, {"x": -1.0, "a": 1.0}
210193
)
194+
self.assertEqual(new_ss.parameter_constraints[1].bound, 1.0)
211195

212196
# Test transform of target value
213-
t = self.transform_class(
214-
search_space=self.search_space_with_target,
215-
observations=[],
216-
)
197+
t = UnitX(search_space=self.search_space_with_target)
217198
new_search_space_with_target = SearchSpace(
218199
parameters=[
219200
RangeParameter(
@@ -227,50 +208,30 @@ def test_TransformNewSearchSpace(self) -> None:
227208
]
228209
)
229210
t.transform_search_space(new_search_space_with_target)
230-
self.assertEqual(
231-
new_search_space_with_target.parameters["x"].target_value,
232-
0.5 * self.target_range + self.target_lb,
233-
)
211+
self.assertEqual(new_search_space_with_target.parameters["x"].target_value, 0.5)
234212

235213
def test_w_robust_search_space_univariate(self) -> None:
236214
# Check that if no transforms are needed, it is untouched.
237215
for multivariate in (True, False):
238-
rss = get_robust_search_space(
239-
multivariate=multivariate,
240-
lb=self.target_lb,
241-
ub=self.target_ub,
242-
)
216+
rss = get_robust_search_space(multivariate=multivariate, lb=0.0, ub=1.0)
243217
expected = str(rss)
244-
t = self.transform_class(
245-
search_space=rss,
246-
observations=[],
247-
)
218+
t = UnitX(search_space=rss)
248219
self.assertEqual(expected, str(t.transform_search_space(rss)))
249220
# Error if distribution is multiplicative.
250221
rss = get_robust_search_space()
251222
rss.parameter_distributions[0].multiplicative = True
252-
t = self.transform_class(
253-
search_space=rss,
254-
observations=[],
255-
)
223+
t = UnitX(search_space=rss)
256224
with self.assertRaisesRegex(NotImplementedError, "multiplicative"):
257225
t.transform_search_space(rss)
258226
# Correctly transform univariate additive distributions.
259227
rss = get_robust_search_space(lb=5.0, ub=10.0)
260-
t = self.transform_class(
261-
search_space=rss,
262-
observations=[],
263-
)
228+
t = UnitX(search_space=rss)
264229
t.transform_search_space(rss)
265230
dists = rss.parameter_distributions
266-
self.assertEqual(
267-
dists[0].distribution_parameters["loc"], 0.2 * self.target_range
268-
)
269-
self.assertEqual(dists[0].distribution_parameters["scale"], self.target_range)
231+
self.assertEqual(dists[0].distribution_parameters["loc"], 0.2)
232+
self.assertEqual(dists[0].distribution_parameters["scale"], 1.0)
270233
self.assertEqual(dists[1].distribution_parameters["loc"], 0.0)
271-
self.assertEqual(
272-
dists[1].distribution_parameters["scale"], 0.2 * self.target_range
273-
)
234+
self.assertEqual(dists[1].distribution_parameters["scale"], 0.2)
274235
# Correctly transform environmental distributions.
275236
rss = get_robust_search_space(lb=5.0, ub=10.0)
276237
all_parameters = list(rss.parameters.values())
@@ -286,22 +247,19 @@ def test_w_robust_search_space_univariate(self) -> None:
286247
dist.distribution_parameters["loc"],
287248
t._normalize_value(1.0, (5.0, 10.0)),
288249
)
289-
self.assertEqual(dist.distribution_parameters["scale"], self.target_range)
250+
self.assertEqual(dist.distribution_parameters["scale"], 1.0)
290251
# Error if transform via loc / scale is not supported.
291252
rss = get_robust_search_space(use_discrete=True)
292253
rss.parameters["z"]._parameter_type = ParameterType.FLOAT
293-
t = self.transform_class(
294-
search_space=rss,
295-
observations=[],
296-
)
254+
t = UnitX(search_space=rss)
297255
with self.assertRaisesRegex(UnsupportedError, "`loc` and `scale`"):
298256
t.transform_search_space(rss)
299257

300258
def test_w_robust_search_space_multivariate(self) -> None:
301259
# Error if trying to transform non-normal multivariate distributions.
302260
rss = get_robust_search_space(multivariate=True)
303261
rss.parameter_distributions[0].distribution_class = "multivariate_t"
304-
t = self.transform_class(
262+
t = UnitX(
305263
search_space=rss,
306264
observations=[],
307265
)
@@ -310,25 +268,16 @@ def test_w_robust_search_space_multivariate(self) -> None:
310268
# Transform multivariate normal.
311269
rss = get_robust_search_space(multivariate=True)
312270
old_params = deepcopy(rss.parameter_distributions[0].distribution_parameters)
313-
t = self.transform_class(
314-
search_space=rss,
315-
observations=[],
316-
)
271+
t = UnitX(search_space=rss)
317272
t.transform_search_space(rss)
318273
new_params = rss.parameter_distributions[0].distribution_parameters
319274
self.assertIsInstance(new_params["mean"], np.ndarray)
320275
self.assertIsInstance(new_params["cov"], np.ndarray)
321276
self.assertTrue(
322-
np.allclose(
323-
new_params["mean"],
324-
np.asarray(old_params["mean"]) / 5.0 * self.target_range,
325-
)
277+
np.allclose(new_params["mean"], np.asarray(old_params["mean"]) / 5.0)
326278
)
327279
self.assertTrue(
328-
np.allclose(
329-
new_params["cov"],
330-
np.asarray(old_params["cov"]) / ((5.0 / self.target_range) ** 2),
331-
)
280+
np.allclose(new_params["cov"], np.asarray(old_params["cov"]) / 25.0)
332281
)
333282
# Transform multivariate normal environmental distribution.
334283
rss = get_robust_search_space(multivariate=True)
@@ -339,18 +288,11 @@ def test_w_robust_search_space_multivariate(self) -> None:
339288
num_samples=rss.num_samples,
340289
environmental_variables=rss_params[:2],
341290
)
342-
t = self.transform_class(
343-
search_space=rss,
344-
observations=[],
345-
)
291+
t = UnitX(search_space=rss)
346292
t.transform_search_space(rss)
347293
new_params = rss.parameter_distributions[0].distribution_parameters
348294
self.assertTrue(
349-
np.allclose(
350-
new_params["mean"],
351-
np.asarray(old_params["mean"]) / 5.0 * self.target_range
352-
+ self.target_lb,
353-
)
295+
np.allclose(new_params["mean"], np.asarray(old_params["mean"]) / 5.0)
354296
)
355297
# Errors if mean / cov are of wrong shape.
356298
rss.parameter_distributions[0].distribution_parameters["mean"] = [1.0]
@@ -360,3 +302,42 @@ def test_w_robust_search_space_multivariate(self) -> None:
360302
rss.parameter_distributions[0].distribution_parameters["cov"] = [1.0]
361303
with self.assertRaisesRegex(UserInputError, "cov"):
362304
t.transform_search_space(rss)
305+
306+
def test_transform_experiment_data(self) -> None:
307+
parameterizations = [
308+
{"x": 1.0, "y": 1.5, "z": 1.0, "a": 1, "b": "b"},
309+
{"x": 2.0, "y": 2.0, "z": 2.0, "a": 2, "b": "b"},
310+
]
311+
experiment = get_experiment_with_observations(
312+
observations=[[1.0], [2.0]],
313+
search_space=self.search_space,
314+
parameterizations=parameterizations,
315+
)
316+
experiment_data = extract_experiment_data(
317+
experiment=experiment, data_loader_config=DataLoaderConfig()
318+
)
319+
transformed_data = self.t.transform_experiment_data(
320+
experiment_data=deepcopy(experiment_data)
321+
)
322+
323+
# Check that `x` and `y` have been transformed.
324+
expected = DataFrame(
325+
index=transformed_data.arm_data.index,
326+
data={
327+
"x": [0.0, 0.5],
328+
"y": [0.5, 1.0],
329+
},
330+
columns=["x", "y"],
331+
)
332+
assert_frame_equal(transformed_data.arm_data[["x", "y"]], expected)
333+
334+
# Remaining columns are unchanged.
335+
# "z" is log-scale and "a" is int, so they're not transformed.
336+
cols = ["z", "a", "b", "metadata"]
337+
assert_frame_equal(
338+
transformed_data.arm_data[cols], experiment_data.arm_data[cols]
339+
)
340+
# Observation data is unchanged.
341+
assert_frame_equal(
342+
transformed_data.observation_data, experiment_data.observation_data
343+
)

0 commit comments

Comments
 (0)