9
9
from copy import deepcopy
10
10
11
11
import numpy as np
12
+ from ax .adapter .base import DataLoaderConfig
13
+ from ax .adapter .data_utils import extract_experiment_data
12
14
from ax .adapter .transforms .unit_x import UnitX
13
15
from ax .core .observation import ObservationFeatures
14
16
from ax .core .parameter import ChoiceParameter , ParameterType , RangeParameter
15
17
from ax .core .parameter_constraint import ParameterConstraint
16
18
from ax .core .search_space import RobustSearchSpace , SearchSpace
17
19
from ax .exceptions .core import UnsupportedError , UserInputError
18
20
from ax .utils .common .testutils import TestCase
19
- from ax .utils .testing .core_stubs import get_robust_search_space
21
+ from ax .utils .testing .core_stubs import (
22
+ get_experiment_with_observations ,
23
+ get_robust_search_space ,
24
+ )
25
+ from pandas import DataFrame
26
+ from pandas .testing import assert_frame_equal
20
27
from pyre_extensions import assert_is_instance
21
28
22
29
23
30
class UnitXTransformTest (TestCase ):
24
- transform_class = UnitX
25
- # pyre-fixme[4]: Attribute must be annotated.
26
- expected_c_dicts = [{"x" : - 1.0 , "y" : 1.0 }, {"x" : - 1.0 , "a" : 1.0 }]
27
- expected_c_bounds = [0.0 , 1.0 ]
28
-
29
31
def setUp (self ) -> None :
30
32
super ().setUp ()
31
- self .target_lb = self .transform_class .target_lb
32
- self .target_range = self .transform_class .target_range
33
- self .target_ub = self .target_lb + self .target_range
34
33
self .search_space = SearchSpace (
35
34
parameters = [
36
35
RangeParameter (
@@ -56,10 +55,7 @@ def setUp(self) -> None:
56
55
ParameterConstraint (constraint_dict = {"x" : - 0.5 , "a" : 1 }, bound = 0.5 ),
57
56
],
58
57
)
59
- self .t = self .transform_class (
60
- search_space = self .search_space ,
61
- observations = [],
62
- )
58
+ self .t = UnitX (search_space = self .search_space )
63
59
self .search_space_with_target = SearchSpace (
64
60
parameters = [
65
61
RangeParameter (
@@ -86,13 +82,7 @@ def test_TransformObservationFeatures(self) -> None:
86
82
obs_ft2 ,
87
83
[
88
84
ObservationFeatures (
89
- parameters = {
90
- "x" : self .target_lb + self .target_range / 2.0 ,
91
- "y" : 1.0 ,
92
- "z" : 2 ,
93
- "a" : 2 ,
94
- "b" : "b" ,
95
- }
85
+ parameters = {"x" : 0.5 , "y" : 1.0 , "z" : 2 , "a" : 2 , "b" : "b" }
96
86
)
97
87
],
98
88
)
@@ -103,7 +93,7 @@ def test_TransformObservationFeatures(self) -> None:
103
93
obs_ft3 = self .t .transform_observation_features (obs_ft3 )
104
94
self .assertEqual (
105
95
obs_ft3 [0 ],
106
- ObservationFeatures (parameters = {"x" : self . target_ub , "z" : 2 }),
96
+ ObservationFeatures (parameters = {"x" : 1.0 , "z" : 2 }),
107
97
)
108
98
obs_ft5 = self .t .transform_observation_features ([ObservationFeatures ({})])
109
99
self .assertEqual (obs_ft5 [0 ], ObservationFeatures ({}))
@@ -114,31 +104,35 @@ def test_TransformSearchSpace(self) -> None:
114
104
115
105
# Parameters transformed
116
106
true_bounds = {
117
- "x" : (self . target_lb , 1.0 ),
118
- "y" : (self . target_lb , 1.0 ),
107
+ "x" : (0.0 , 1.0 ),
108
+ "y" : (0.0 , 1.0 ),
119
109
"z" : (1.0 , 2.0 ),
120
110
"a" : (1.0 , 2.0 ),
121
111
}
122
112
for p_name , (l , u ) in true_bounds .items ():
123
- self .assertEqual (ss2 .parameters [p_name ].lower , l )
124
- self .assertEqual (ss2 .parameters [p_name ].upper , u )
125
- self .assertEqual (ss2 .parameters ["b" ].values , ["a" , "b" , "c" ])
113
+ self .assertEqual (
114
+ assert_is_instance (ss2 .parameters [p_name ], RangeParameter ).lower , l
115
+ )
116
+ self .assertEqual (
117
+ assert_is_instance (ss2 .parameters [p_name ], RangeParameter ).upper , u
118
+ )
119
+ self .assertEqual (
120
+ assert_is_instance (ss2 .parameters ["b" ], ChoiceParameter ).values ,
121
+ ["a" , "b" , "c" ],
122
+ )
126
123
self .assertEqual (len (ss2 .parameters ), 5 )
127
124
# Constraints transformed
128
125
self .assertEqual (
129
- ss2 .parameter_constraints [0 ].constraint_dict , self . expected_c_dicts [ 0 ]
126
+ ss2 .parameter_constraints [0 ].constraint_dict , { "x" : - 1.0 , "y" : 1.0 }
130
127
)
131
- self .assertEqual (ss2 .parameter_constraints [0 ].bound , self . expected_c_bounds [ 0 ] )
128
+ self .assertEqual (ss2 .parameter_constraints [0 ].bound , 0.0 )
132
129
self .assertEqual (
133
- ss2 .parameter_constraints [1 ].constraint_dict , self . expected_c_dicts [ 1 ]
130
+ ss2 .parameter_constraints [1 ].constraint_dict , { "x" : - 1.0 , "a" : 1.0 }
134
131
)
135
- self .assertEqual (ss2 .parameter_constraints [1 ].bound , self . expected_c_bounds [ 1 ] )
132
+ self .assertEqual (ss2 .parameter_constraints [1 ].bound , 1.0 )
136
133
137
134
# Test transform of target value
138
- t = self .transform_class (
139
- search_space = self .search_space_with_target ,
140
- observations = [],
141
- )
135
+ t = UnitX (search_space = self .search_space_with_target )
142
136
t .transform_search_space (self .search_space_with_target )
143
137
self .assertEqual (
144
138
self .search_space_with_target .parameters ["x" ].target_value , 1.0
@@ -175,14 +169,8 @@ def test_TransformNewSearchSpace(self) -> None:
175
169
self .t .transform_search_space (new_ss )
176
170
# Parameters transformed
177
171
true_bounds = {
178
- "x" : [
179
- 0.25 * self .target_range + self .target_lb ,
180
- 0.5 * self .target_range + self .target_lb ,
181
- ],
182
- "y" : [
183
- 0.25 * self .target_range + self .target_lb ,
184
- 1.0 * self .target_range + self .target_lb ,
185
- ],
172
+ "x" : [0.25 , 0.5 ],
173
+ "y" : [0.25 , 1.0 ],
186
174
"z" : [1.0 , 1.5 ],
187
175
"a" : [0 , 2 ],
188
176
}
@@ -197,23 +185,16 @@ def test_TransformNewSearchSpace(self) -> None:
197
185
self .assertEqual (len (new_ss .parameters ), 5 )
198
186
# # Constraints transformed
199
187
self .assertEqual (
200
- new_ss .parameter_constraints [0 ].constraint_dict , self .expected_c_dicts [0 ]
201
- )
202
- self .assertEqual (
203
- new_ss .parameter_constraints [0 ].bound , self .expected_c_bounds [0 ]
204
- )
205
- self .assertEqual (
206
- new_ss .parameter_constraints [1 ].constraint_dict , self .expected_c_dicts [1 ]
188
+ new_ss .parameter_constraints [0 ].constraint_dict , {"x" : - 1.0 , "y" : 1.0 }
207
189
)
190
+ self .assertEqual (new_ss .parameter_constraints [0 ].bound , 0.0 )
208
191
self .assertEqual (
209
- new_ss .parameter_constraints [1 ].bound , self . expected_c_bounds [ 1 ]
192
+ new_ss .parameter_constraints [1 ].constraint_dict , { "x" : - 1.0 , "a" : 1.0 }
210
193
)
194
+ self .assertEqual (new_ss .parameter_constraints [1 ].bound , 1.0 )
211
195
212
196
# Test transform of target value
213
- t = self .transform_class (
214
- search_space = self .search_space_with_target ,
215
- observations = [],
216
- )
197
+ t = UnitX (search_space = self .search_space_with_target )
217
198
new_search_space_with_target = SearchSpace (
218
199
parameters = [
219
200
RangeParameter (
@@ -227,50 +208,30 @@ def test_TransformNewSearchSpace(self) -> None:
227
208
]
228
209
)
229
210
t .transform_search_space (new_search_space_with_target )
230
- self .assertEqual (
231
- new_search_space_with_target .parameters ["x" ].target_value ,
232
- 0.5 * self .target_range + self .target_lb ,
233
- )
211
+ self .assertEqual (new_search_space_with_target .parameters ["x" ].target_value , 0.5 )
234
212
235
213
def test_w_robust_search_space_univariate (self ) -> None :
236
214
# Check that if no transforms are needed, it is untouched.
237
215
for multivariate in (True , False ):
238
- rss = get_robust_search_space (
239
- multivariate = multivariate ,
240
- lb = self .target_lb ,
241
- ub = self .target_ub ,
242
- )
216
+ rss = get_robust_search_space (multivariate = multivariate , lb = 0.0 , ub = 1.0 )
243
217
expected = str (rss )
244
- t = self .transform_class (
245
- search_space = rss ,
246
- observations = [],
247
- )
218
+ t = UnitX (search_space = rss )
248
219
self .assertEqual (expected , str (t .transform_search_space (rss )))
249
220
# Error if distribution is multiplicative.
250
221
rss = get_robust_search_space ()
251
222
rss .parameter_distributions [0 ].multiplicative = True
252
- t = self .transform_class (
253
- search_space = rss ,
254
- observations = [],
255
- )
223
+ t = UnitX (search_space = rss )
256
224
with self .assertRaisesRegex (NotImplementedError , "multiplicative" ):
257
225
t .transform_search_space (rss )
258
226
# Correctly transform univariate additive distributions.
259
227
rss = get_robust_search_space (lb = 5.0 , ub = 10.0 )
260
- t = self .transform_class (
261
- search_space = rss ,
262
- observations = [],
263
- )
228
+ t = UnitX (search_space = rss )
264
229
t .transform_search_space (rss )
265
230
dists = rss .parameter_distributions
266
- self .assertEqual (
267
- dists [0 ].distribution_parameters ["loc" ], 0.2 * self .target_range
268
- )
269
- self .assertEqual (dists [0 ].distribution_parameters ["scale" ], self .target_range )
231
+ self .assertEqual (dists [0 ].distribution_parameters ["loc" ], 0.2 )
232
+ self .assertEqual (dists [0 ].distribution_parameters ["scale" ], 1.0 )
270
233
self .assertEqual (dists [1 ].distribution_parameters ["loc" ], 0.0 )
271
- self .assertEqual (
272
- dists [1 ].distribution_parameters ["scale" ], 0.2 * self .target_range
273
- )
234
+ self .assertEqual (dists [1 ].distribution_parameters ["scale" ], 0.2 )
274
235
# Correctly transform environmental distributions.
275
236
rss = get_robust_search_space (lb = 5.0 , ub = 10.0 )
276
237
all_parameters = list (rss .parameters .values ())
@@ -286,22 +247,19 @@ def test_w_robust_search_space_univariate(self) -> None:
286
247
dist .distribution_parameters ["loc" ],
287
248
t ._normalize_value (1.0 , (5.0 , 10.0 )),
288
249
)
289
- self .assertEqual (dist .distribution_parameters ["scale" ], self . target_range )
250
+ self .assertEqual (dist .distribution_parameters ["scale" ], 1.0 )
290
251
# Error if transform via loc / scale is not supported.
291
252
rss = get_robust_search_space (use_discrete = True )
292
253
rss .parameters ["z" ]._parameter_type = ParameterType .FLOAT
293
- t = self .transform_class (
294
- search_space = rss ,
295
- observations = [],
296
- )
254
+ t = UnitX (search_space = rss )
297
255
with self .assertRaisesRegex (UnsupportedError , "`loc` and `scale`" ):
298
256
t .transform_search_space (rss )
299
257
300
258
def test_w_robust_search_space_multivariate (self ) -> None :
301
259
# Error if trying to transform non-normal multivariate distributions.
302
260
rss = get_robust_search_space (multivariate = True )
303
261
rss .parameter_distributions [0 ].distribution_class = "multivariate_t"
304
- t = self . transform_class (
262
+ t = UnitX (
305
263
search_space = rss ,
306
264
observations = [],
307
265
)
@@ -310,25 +268,16 @@ def test_w_robust_search_space_multivariate(self) -> None:
310
268
# Transform multivariate normal.
311
269
rss = get_robust_search_space (multivariate = True )
312
270
old_params = deepcopy (rss .parameter_distributions [0 ].distribution_parameters )
313
- t = self .transform_class (
314
- search_space = rss ,
315
- observations = [],
316
- )
271
+ t = UnitX (search_space = rss )
317
272
t .transform_search_space (rss )
318
273
new_params = rss .parameter_distributions [0 ].distribution_parameters
319
274
self .assertIsInstance (new_params ["mean" ], np .ndarray )
320
275
self .assertIsInstance (new_params ["cov" ], np .ndarray )
321
276
self .assertTrue (
322
- np .allclose (
323
- new_params ["mean" ],
324
- np .asarray (old_params ["mean" ]) / 5.0 * self .target_range ,
325
- )
277
+ np .allclose (new_params ["mean" ], np .asarray (old_params ["mean" ]) / 5.0 )
326
278
)
327
279
self .assertTrue (
328
- np .allclose (
329
- new_params ["cov" ],
330
- np .asarray (old_params ["cov" ]) / ((5.0 / self .target_range ) ** 2 ),
331
- )
280
+ np .allclose (new_params ["cov" ], np .asarray (old_params ["cov" ]) / 25.0 )
332
281
)
333
282
# Transform multivariate normal environmental distribution.
334
283
rss = get_robust_search_space (multivariate = True )
@@ -339,18 +288,11 @@ def test_w_robust_search_space_multivariate(self) -> None:
339
288
num_samples = rss .num_samples ,
340
289
environmental_variables = rss_params [:2 ],
341
290
)
342
- t = self .transform_class (
343
- search_space = rss ,
344
- observations = [],
345
- )
291
+ t = UnitX (search_space = rss )
346
292
t .transform_search_space (rss )
347
293
new_params = rss .parameter_distributions [0 ].distribution_parameters
348
294
self .assertTrue (
349
- np .allclose (
350
- new_params ["mean" ],
351
- np .asarray (old_params ["mean" ]) / 5.0 * self .target_range
352
- + self .target_lb ,
353
- )
295
+ np .allclose (new_params ["mean" ], np .asarray (old_params ["mean" ]) / 5.0 )
354
296
)
355
297
# Errors if mean / cov are of wrong shape.
356
298
rss .parameter_distributions [0 ].distribution_parameters ["mean" ] = [1.0 ]
@@ -360,3 +302,42 @@ def test_w_robust_search_space_multivariate(self) -> None:
360
302
rss .parameter_distributions [0 ].distribution_parameters ["cov" ] = [1.0 ]
361
303
with self .assertRaisesRegex (UserInputError , "cov" ):
362
304
t .transform_search_space (rss )
305
+
306
+ def test_transform_experiment_data (self ) -> None :
307
+ parameterizations = [
308
+ {"x" : 1.0 , "y" : 1.5 , "z" : 1.0 , "a" : 1 , "b" : "b" },
309
+ {"x" : 2.0 , "y" : 2.0 , "z" : 2.0 , "a" : 2 , "b" : "b" },
310
+ ]
311
+ experiment = get_experiment_with_observations (
312
+ observations = [[1.0 ], [2.0 ]],
313
+ search_space = self .search_space ,
314
+ parameterizations = parameterizations ,
315
+ )
316
+ experiment_data = extract_experiment_data (
317
+ experiment = experiment , data_loader_config = DataLoaderConfig ()
318
+ )
319
+ transformed_data = self .t .transform_experiment_data (
320
+ experiment_data = deepcopy (experiment_data )
321
+ )
322
+
323
+ # Check that `x` and `y` have been transformed.
324
+ expected = DataFrame (
325
+ index = transformed_data .arm_data .index ,
326
+ data = {
327
+ "x" : [0.0 , 0.5 ],
328
+ "y" : [0.5 , 1.0 ],
329
+ },
330
+ columns = ["x" , "y" ],
331
+ )
332
+ assert_frame_equal (transformed_data .arm_data [["x" , "y" ]], expected )
333
+
334
+ # Remaining columns are unchanged.
335
+ # "z" is log-scale and "a" is in, so they're not transformed.
336
+ cols = ["z" , "a" , "b" , "metadata" ]
337
+ assert_frame_equal (
338
+ transformed_data .arm_data [cols ], experiment_data .arm_data [cols ]
339
+ )
340
+ # Observation data is unchanged.
341
+ assert_frame_equal (
342
+ transformed_data .observation_data , experiment_data .observation_data
343
+ )
0 commit comments