-
Notifications
You must be signed in to change notification settings - Fork 2k
Expand file tree
/
Copy pathAutoMLBuildSpec.java
More file actions
383 lines (314 loc) · 13.5 KB
/
AutoMLBuildSpec.java
File metadata and controls
383 lines (314 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
package ai.h2o.automl;
import ai.h2o.automl.preprocessing.PreprocessingStepDefinition;
import hex.Model;
import hex.ScoreKeeper.StoppingMetric;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.HyperSpaceSearchCriteria;
import water.H2O;
import water.Iced;
import water.Key;
import water.exceptions.H2OIllegalValueException;
import water.fvec.Frame;
import water.util.*;
import water.util.PojoUtils.FieldNaming;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
/**
* Parameters which specify the build (or extension) of an AutoML build job.
*/
public class AutoMLBuildSpec extends Iced {
// Per-thread date formatter used by instanceId() to timestamp generated instance ids.
// It is wrapped in a ThreadLocal so that every thread formats with its own SimpleDateFormat instance.
// NOTE(review): the pattern uses a single 'H', so the hour is not zero-padded and the timestamp part
// is 5 or 6 digits long depending on the time of day — confirm this is intended before changing it.
private static final ThreadLocal<DateFormat> instanceTimeStampFormat = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyyMMdd_Hmmss"));
// Static counter incremented by instanceId() each time a new instance id is generated.
private final static AtomicInteger amlInstanceCounter = new AtomicInteger();
/**
* The specification of overall build parameters for the AutoML process.
* <p>
* Apart from {@code stopping_criteria}, which is created by the constructor, every option is a public
* mutable field whose default is the value shown on its declaration.
*/
public static final class AutoMLBuildControl extends Iced {
// Stopping criteria for the build; instantiated with its own defaults by the constructor below.
public final AutoMLStoppingCriteria stopping_criteria;
/**
* Identifier for models that should be grouped together in the leaderboard (e.g., "airlines" and "iris").
* When left null, AutoMLBuildSpec.project() replaces it with the generated instance id.
*/
public String project_name = null;
// Pass through to all algorithms
public boolean balance_classes = false;
public float[] class_sampling_factors;
public float max_after_balance_size = 5.0f;
// NOTE(review): -1 presumably means "not set by the user, let AutoML decide" — confirm against the consumer of this field.
public int nfolds = -1;
public DistributionFamily distribution = DistributionFamily.AUTO;
public String custom_distribution_func;
public double tweedie_power = 1.5;
public double quantile_alpha = 0.5;
public double huber_alpha = 0.9;
public String custom_metric_func;
public boolean keep_cross_validation_predictions = false;
public boolean keep_cross_validation_models = false;
public boolean keep_cross_validation_fold_assignment = false;
public String export_checkpoints_dir = null;
/**
* Creates a build control holding a fresh {@link AutoMLStoppingCriteria}; all other fields keep their declared defaults.
*/
public AutoMLBuildControl() {
stopping_criteria = new AutoMLStoppingCriteria();
}
}
/**
 * Stopping criteria of an AutoML build.
 * <p>
 * Every value except the per-model runtime limit is stored in, and read back from, a wrapped
 * {@link HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria}; the per-model limit is held by this class itself.
 */
public static final class AutoMLStoppingCriteria extends Iced {

    /** Value assigned to the stopping tolerance by the constructor. */
    public static final int AUTO_STOPPING_TOLERANCE = -1;

    private final HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria _searchCriteria;
    private double _max_runtime_secs_per_model;

    /**
     * Creates the criteria with reasonable defaults: no limit on the number of models, on the total runtime
     * or on the runtime per model; 3 stopping rounds; {@link #AUTO_STOPPING_TOLERANCE} as tolerance
     * and {@link StoppingMetric#AUTO} as metric.
     */
    public AutoMLStoppingCriteria() {
        _searchCriteria = new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
        _searchCriteria.set_max_models(0);          // no limit
        _searchCriteria.set_max_runtime_secs(0d);   // no limit
        _max_runtime_secs_per_model = 0;            // no limit
        _searchCriteria.set_stopping_rounds(3);
        _searchCriteria.set_stopping_tolerance((double) AUTO_STOPPING_TOLERANCE);
        _searchCriteria.set_stopping_metric(StoppingMetric.AUTO);
    }

    /**
     * Delegates to the static method of the same name on
     * {@link HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria}.
     *
     * @param frame the frame passed to the delegate.
     * @return the tolerance computed by the delegate for that frame.
     */
    public static double default_stopping_tolerance_for_frame(Frame frame) {
        return HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria.default_stopping_tolerance_for_frame(frame);
    }

    /** @return the wrapped search criteria instance (not a copy). */
    public HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria getSearchCriteria() {
        return _searchCriteria;
    }

    // --- per-model runtime limit: stored locally ---

    public double max_runtime_secs_per_model() {
        return _max_runtime_secs_per_model;
    }

    public void set_max_runtime_secs_per_model(double max_runtime_secs_per_model) {
        _max_runtime_secs_per_model = max_runtime_secs_per_model;
    }

    // --- everything below simply forwards to the wrapped search criteria ---

    public long seed() {
        return _searchCriteria.seed();
    }

    public void set_seed(long seed) {
        _searchCriteria.set_seed(seed);
    }

    public int max_models() {
        return _searchCriteria.max_models();
    }

    public void set_max_models(int max_models) {
        _searchCriteria.set_max_models(max_models);
    }

    public double max_runtime_secs() {
        return _searchCriteria.max_runtime_secs();
    }

    public void set_max_runtime_secs(double max_runtime_secs) {
        _searchCriteria.set_max_runtime_secs(max_runtime_secs);
    }

    public int stopping_rounds() {
        return _searchCriteria.stopping_rounds();
    }

    public void set_stopping_rounds(int stopping_rounds) {
        _searchCriteria.set_stopping_rounds(stopping_rounds);
    }

    public StoppingMetric stopping_metric() {
        return _searchCriteria.stopping_metric();
    }

    public void set_stopping_metric(StoppingMetric stopping_metric) {
        _searchCriteria.set_stopping_metric(stopping_metric);
    }

    public double stopping_tolerance() {
        return _searchCriteria.stopping_tolerance();
    }

    public void set_stopping_tolerance(double stopping_tolerance) {
        _searchCriteria.set_stopping_tolerance(stopping_tolerance);
    }

    public void set_default_stopping_tolerance_for_frame(Frame frame) {
        _searchCriteria.set_default_stopping_tolerance_for_frame(frame);
    }
}
/**
* The specification of the datasets to be used for the AutoML process.
* The user can specify a directory path, a file path (including HDFS, s3 or the like),
* or the ID of an already-parsed Frame in the H2O cluster. Paths are processed
* as usual in H2O.
* <p>
* In this class the datasets are held as {@code Key<Frame>} references and the columns as plain column names.
*/
public static final class AutoMLInput extends Iced {
// Frame references; none of them has a default value here.
public Key<Frame> training_frame;
public Key<Frame> validation_frame;
public Key<Frame> blending_frame;
public Key<Frame> leaderboard_frame;
// Column names. response_column, fold_column and weights_column are the ones reported
// by AutoMLBuildSpec.getNonPredictors() when they are non-null.
public String response_column;
public String fold_column;
public String weights_column;
public String[] ignored_columns;
// Defaults to the name of StoppingMetric.AUTO.
public String sort_metric = StoppingMetric.AUTO.name();
}
/**
* The specification of the parameters for building models for a single algo (e.g., GBM), including base model parameters and hyperparameter search.
* <p>
* NOTE(review): the fields below hold arrays of algos and a whole modeling plan, so the "single algo" wording
* above may be outdated — confirm.
*/
public static final class AutoMLBuildModels extends Iced {
// Algos to exclude / include; both are unset (null) by default.
public Algo[] exclude_algos;
public Algo[] include_algos;
public StepDefinition[] modeling_plan;
// NOTE(review): -1 presumably means "not set by the user" — confirm against the consumer of this field.
public double exploitation_ratio = -1;
// Custom per-algo parameters; starts as an empty AutoMLCustomParameters instance.
public AutoMLCustomParameters algo_parameters = new AutoMLCustomParameters();
public PreprocessingStepDefinition[] preprocessing;
}
/**
 * Custom model parameters that override the default parameters of the algos.
 * <p>
 * Instances are assembled through {@link Builder} (see {@link #create()}). For each algo, this class keeps
 * a {@link Model.Parameters} object carrying the customized values and the list of parameter names
 * that were overridden; {@link #applyCustomParameters(IAlgo, Model.Parameters)} copies those overridden
 * values onto another parameters object.
 */
public static final class AutoMLCustomParameters extends Iced {

    // convenient property to allow us to modify our model (and later grids) definitions
    // and benchmark them without having to rebuild the backend for each change.
    static final String ALGO_PARAMS_ALL_ENABLED = H2O.OptArgs.SYSTEM_PROP_PREFIX + "automl.algo_parameters.all.enabled";

    // let's limit the list of allowed custom parameters by default for now: we can always decide to open this later.
    private static final String[] ALLOWED_PARAMETERS = {
            "monotone_constraints",
            "auc_type"
//          "ntrees",
    };

    // name of the top-level parameter, used when reporting illegal values.
    private static final String ROOT_PARAM = "algo_parameters";

    /**
     * A single custom parameter: a name, a value and, optionally, the algo it is restricted to.
     *
     * @param <V> the type of the parameter value.
     */
    public static final class AutoMLCustomParameter<V> extends Iced {
        /** Creates a parameter that is not bound to a specific algo ({@code _algo} stays null). */
        private AutoMLCustomParameter(String name, V value) {
            _name = name;
            _value = value;
        }

        /** Creates a parameter bound to the given algo. */
        private AutoMLCustomParameter(IAlgo algo, String name, V value) {
            _algo = algo;
            _name = name;
            _value = value;
        }

        private IAlgo _algo;
        private String _name;
        private V _value;
    }

    /**
     * Collects custom parameters and creates the {@link AutoMLCustomParameters} instance from them.
     */
    public static final class Builder {
        // Wildcard-typed (instead of raw) lists: parameters of any value type can be stored
        // without giving up compile-time type checking.
        private final transient List<AutoMLCustomParameter<?>> _anyAlgoParams = new ArrayList<>();
        private final transient List<AutoMLCustomParameter<?>> _specificAlgoParams = new ArrayList<>();

        /**
         * Registers a parameter that is not bound to a specific algo.
         *
         * @param param the parameter name.
         * @param value the parameter value.
         * @return this builder.
         * @throws H2OIllegalValueException if the parameter is not allowed (see {@code assertParameterAllowed}).
         */
        public <V> Builder add(String param, V value) {
            assertParameterAllowed(param);
            _anyAlgoParams.add(new AutoMLCustomParameter<>(param, value));
            return this;
        }

        /**
         * Registers a parameter for the given algo only.
         *
         * @param algo  the algo the parameter applies to.
         * @param param the parameter name.
         * @param value the parameter value.
         * @return this builder.
         * @throws H2OIllegalValueException if the parameter is not allowed (see {@code assertParameterAllowed}).
         */
        public <V> Builder add(IAlgo algo, String param, V value) {
            assertParameterAllowed(param);
            _specificAlgoParams.add(new AutoMLCustomParameter<>(algo, param, value));
            return this;
        }

        /**
         * Builder is necessary here as the custom parameters must be applied in a certain order,
         * and we can't assume that the consumer of this API will add them in the right order.
         * @return a new AutoMLCustomParameters instance with custom parameters properly assigned.
         * @throws H2OIllegalValueException if a parameter could not be assigned to any algo
         *         (respectively to its specific algo), or if its value was rejected.
         */
        public AutoMLCustomParameters build() {
            AutoMLCustomParameters instance = new AutoMLCustomParameters();
            // apply "all" scope first, then algo-specific ones.
            for (AutoMLCustomParameter<?> param : _anyAlgoParams) {
                if (!instance.addParameter(param._name, param._value)) {
                    throw new H2OIllegalValueException(param._name, ROOT_PARAM, param._value);
                }
            }
            for (AutoMLCustomParameter<?> param : _specificAlgoParams) {
                if (!instance.addParameter(param._algo, param._name, param._value)) {
                    throw new H2OIllegalValueException(param._name, ROOT_PARAM, param._value);
                }
            }
            return instance;
        }

        /**
         * Rejects any parameter that is not listed in {@code ALLOWED_PARAMETERS}, unless the system property
         * named by {@code ALGO_PARAMS_ALL_ENABLED} is set to "true" (it is treated as "false" when absent).
         *
         * @throws H2OIllegalValueException if the parameter is rejected.
         */
        private void assertParameterAllowed(String param) {
            if (!Boolean.parseBoolean(System.getProperty(ALGO_PARAMS_ALL_ENABLED, "false"))
                    && !ArrayUtils.contains(ALLOWED_PARAMETERS, param)) {
                throw new H2OIllegalValueException(ROOT_PARAM, param);
            }
        }
    }

    /** @return a new, empty {@link Builder}. */
    public static Builder create() {
        return new Builder();
    }

    private final IcedHashMap<String, String[]> _algoParameterNames = new IcedHashMap<>(); // stores the parameters names overridden, by algo name
    private final IcedHashMap<String, Model.Parameters> _algoParameters = new IcedHashMap<>(); //stores the parameters values, by algo name

    /** @return true iff at least one parameter name has been recorded for the given algo. */
    public boolean hasCustomParams(IAlgo algo) {
        return _algoParameterNames.get(algo.name()) != null;
    }

    /** @return whether {@code param} is among the parameter names recorded for the given algo. */
    public boolean hasCustomParam(IAlgo algo, String param) {
        return ArrayUtils.contains(_algoParameterNames.get(algo.name()), param);
    }

    /**
     * Copies the custom parameters recorded for the given algo onto {@code destParams}.
     * Does nothing if no custom parameter was recorded for that algo.
     *
     * @param algo       the algo whose custom parameters should be applied.
     * @param destParams the parameters object receiving the custom values.
     */
    public void applyCustomParameters(IAlgo algo, Model.Parameters destParams) {
        if (hasCustomParams(algo)) {
            String[] paramNames = getCustomParameterNames(algo);
            // the recorded names carry no leading underscore: prefix them before handing them to PojoUtils.
            // NOTE(review): the last argument is presumably the "copy only these fields" filter — confirm against PojoUtils.copyProperties.
            String[] onlyParamNames = Stream.of(paramNames).map(p -> "_"+p).toArray(String[]::new);
            PojoUtils.copyProperties(destParams, getCustomizedDefaults(algo), FieldNaming.CONSISTENT, null, onlyParamNames);
        }
    }

    /** @return the parameter names recorded for the given algo, or null if there are none. */
    String[] getCustomParameterNames(IAlgo algo) {
        return _algoParameterNames.get(algo.name());
    }

    /**
     * Returns the parameters object holding the customized values for the given algo.
     * On first access it is created from {@link #defaultParameters(IAlgo)} and cached,
     * unless no defaults are available, in which case nothing is cached.
     *
     * @return the cached parameters, or null if no defaults are available for the algo.
     */
    Model.Parameters getCustomizedDefaults(IAlgo algo) {
        if (!_algoParameters.containsKey(algo.name())) {
            Model.Parameters defaults = defaultParameters(algo);
            if (defaults != null) {
                _algoParameters.put(algo.name(), defaults);
            }
        }
        return _algoParameters.get(algo.name());
    }

    /** @return the default parameters obtained from {@code ModelingStepsRegistry} if the algo is enabled, null otherwise. */
    private Model.Parameters defaultParameters(IAlgo algo) {
        return algo.enabled() ? ModelingStepsRegistry.defaultParameters(algo.name()) : null;
    }

    /** Records {@code param} as overridden for the given algo, avoiding duplicates. */
    private void addParameterName(IAlgo algo, String param) {
        if (!_algoParameterNames.containsKey(algo.name())) {
            _algoParameterNames.put(algo.name(), new String[] {param});
        } else {
            String[] names = _algoParameterNames.get(algo.name());
            if (!ArrayUtils.contains(names, param)) {
                _algoParameterNames.put(algo.name(), ArrayUtils.append(names, param));
            }
        }
    }

    /**
     * Tries to assign the parameter to every value of {@link Algo}.
     *
     * @return true iff the parameter could be assigned to at least one algo.
     */
    private <V> boolean addParameter(String param, V value) {
        boolean added = false;
        for (Algo algo : Algo.values()) {
            // non short-circuit on purpose: every algo must be tried.
            added |= addParameter(algo, param, value);
        }
        return added;
    }

    /**
     * Tries to assign the parameter to the customized defaults of the given algo, first with the
     * {@code DEST_HAS_UNDERSCORES} field naming, then with {@code CONSISTENT}.
     *
     * @return true if the value was assigned (the name is then recorded for the algo),
     *         false if no defaults are available for the algo or if it has no such field.
     * @throws H2OIllegalValueException if an {@link IllegalArgumentException} was raised during the assignment.
     */
    private <V> boolean addParameter(IAlgo algo, String param, V value) {
        Model.Parameters customParams = getCustomizedDefaults(algo);
        try {
            if (customParams != null
                    && (setField(customParams, param, value, FieldNaming.DEST_HAS_UNDERSCORES)
                        || setField(customParams, param, value, FieldNaming.CONSISTENT))) {
                addParameterName(algo, param);
                return true;
            } else {
                Log.debug("Could not set custom param " + param + " for algo " + algo);
                return false;
            }
        } catch (IllegalArgumentException iae) {
            throw new H2OIllegalValueException(param, ROOT_PARAM, value);
        }
    }

    /**
     * Sets a field on {@code dest} through {@code PojoUtils}.
     *
     * @return true if the field was set, false if reading that field also fails
     *         (interpreted as "the field doesn't exist").
     * @throws IllegalArgumentException the original exception, if setting failed although the field can be read.
     */
    private <D, V> boolean setField(D dest, String fieldName, V value, FieldNaming naming) {
        try {
            PojoUtils.setField(dest, fieldName, value, naming);
            return true;
        } catch (IllegalArgumentException iae) {
            // propagate exception iff the value was wrong (conversion issue), ignore if the field doesn't exist.
            try {
                PojoUtils.getFieldValue(dest, fieldName, naming);
            } catch (IllegalArgumentException ignored){
                return false;
            }
            throw iae;
        }
    }
}
// The three parts of the specification; each is created once, with its defaults, when the spec is instantiated.
public final AutoMLBuildControl build_control = new AutoMLBuildControl();
public final AutoMLInput input_spec = new AutoMLInput();
public final AutoMLBuildModels build_models = new AutoMLBuildModels();
// Generated lazily by instanceId() on first call; null until then.
private String instanceId;
/**
 * Returns the project name of this build.
 * If none was set, {@link #instanceId()} is stored in {@code build_control.project_name} and used instead.
 *
 * @return the project name.
 */
public String project() {
    final String configured = build_control.project_name;
    if (configured != null) {
        return configured;
    }
    build_control.project_name = instanceId();
    return build_control.project_name;
}
/**
 * Returns the id of this instance, generating it on first call as
 * {@code "AutoML_" + <next counter value> + "_" + <current timestamp>}.
 *
 * @return the instance id; the same value is returned by all later calls.
 */
public String instanceId() {
    if (instanceId != null) {
        return instanceId;
    }
    // the counter is incremented before the timestamp is taken.
    final int sequence = amlInstanceCounter.incrementAndGet();
    final String timestamp = instanceTimeStampFormat.get().format(new Date());
    instanceId = "AutoML_" + sequence + "_" + timestamp;
    return instanceId;
}
/**
 * Builds the key of the AutoML instance from the project name and the sanitized response column.
 * <p>
 * If the user offers a different response column, the new models will be added to a new Leaderboard,
 * without removing the previous one; otherwise, the new models will be added to the existing leaderboard.
 *
 * @return the key made of {@code project()}, {@code AutoML.keySeparator} and the sanitized response column.
 */
public Key<AutoML> makeKey() {
    final String projectName = project();
    return Key.make(projectName + AutoML.keySeparator + StringUtils.sanitizeIdentifier(input_spec.response_column));
}
/**
 * Lists the columns of the input spec that are not predictors.
 *
 * @return the weights, fold and response column names, in that order, leaving out those that are null.
 */
public String[] getNonPredictors() {
    final String[] candidates = {input_spec.weights_column, input_spec.fold_column, input_spec.response_column};
    final List<String> defined = new ArrayList<>(candidates.length);
    for (String column : candidates) {
        if (column != null) {
            defined.add(column);
        }
    }
    return defined.toArray(new String[0]);
}
}