Skip to content

Commit 0f03651

Browse files
authored
Fixes sources of indeterminism for tests and Pipeline steps (#1234)
* Fixed module idempotent tests to match sklearns descriptions * Added pytest.ini so examples with "test" in name don't run * random_state now documented and used properly * flake8'd * Removed stray print statement * Fixed failing tests * Added testing for pipeline steps random_states * Added random_state to Pipelines that have fit called * flake8'd * Made mlp tests less sensitive to platform differences * review changes
1 parent 19a9573 commit 0f03651

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+744
-430
lines changed

autosklearn/automl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import joblib
3333
import sklearn.utils
3434
from scipy.sparse import spmatrix
35+
from sklearn.utils import check_random_state
3536
from sklearn.utils.validation import check_is_fitted
3637
from sklearn.metrics._classification import type_of_target
3738
from sklearn.dummy import DummyClassifier, DummyRegressor
@@ -1165,7 +1166,7 @@ def refit(self, X, y):
11651166
if self.ensemble_ is None:
11661167
raise ValueError("Refit can only be called if 'ensemble_size != 0'")
11671168

1168-
random_state = np.random.RandomState(self._seed)
1169+
random_state = check_random_state(self._seed)
11691170
for identifier in self.models_:
11701171
model = self.models_[identifier]
11711172
# this updates the model inplace, it can then later be used in

autosklearn/ensemble_builder.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import numpy as np
2020
import pandas as pd
2121
import pynisher
22-
from sklearn.utils.validation import check_random_state
2322
from smac.callbacks import IncorporateRunResultCallback
2423
from smac.optimizer.smbo import SMBO
2524
from smac.runhistory.runhistory import RunInfo, RunValue
@@ -57,7 +56,7 @@ def __init__(
5756
max_iterations: Optional[int],
5857
read_at_most: int,
5958
ensemble_memory_limit: Optional[int],
60-
random_state: int,
59+
random_state: Union[int, np.random.RandomState],
6160
logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
6261
pynisher_context: str = 'fork',
6362
):
@@ -228,7 +227,7 @@ def build_ensemble(
228227
precision=self.precision,
229228
memory_limit=self.ensemble_memory_limit,
230229
read_at_most=self.read_at_most,
231-
random_state=self.seed,
230+
random_state=self.random_state,
232231
end_at=self.start_time + self.time_left_for_ensembles,
233232
iteration=self.iteration,
234233
return_predictions=False,
@@ -266,15 +265,15 @@ def fit_and_return_ensemble(
266265
max_models_on_disc: Union[float, int],
267266
seed: int,
268267
precision: int,
269-
memory_limit: Optional[int],
270268
read_at_most: int,
271-
random_state: int,
272269
end_at: float,
273270
iteration: int,
274271
return_predictions: bool,
275272
pynisher_context: str,
276273
logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
277274
unit_test: bool = False,
275+
memory_limit: Optional[int] = None,
276+
random_state: Optional[Union[int, np.random.RandomState]] = None,
278277
) -> Tuple[
279278
List[Tuple[int, float, float, float]],
280279
int,
@@ -318,8 +317,6 @@ def fit_and_return_ensemble(
318317
random seed
319318
precision: [16,32,64,128]
320319
precision of floats to read the predictions
321-
memory_limit: Optional[int]
322-
memory limit in mb. If ``None``, no memory limit is enforced.
323320
read_at_most: int
324321
read at most n new prediction files in each iteration
325322
end_at: float
@@ -329,13 +326,17 @@ def fit_and_return_ensemble(
329326
The current iteration
330327
pynisher_context: str
331328
Context to use for multiprocessing, can be either fork, spawn or forkserver.
332-
logger_port: int
329+
logger_port: int = DEFAULT_TCP_LOGGING_PORT
333330
The port where the logging server is listening to.
334-
unit_test: bool
331+
unit_test: bool = False
335332
Turn on unit testing mode. This currently makes fit_ensemble raise a MemoryError.
336333
Having this is very bad coding style, but I did not find a way to make
337334
unittest.mock work through the pynisher with all spawn contexts. If you know a
338335
better solution, please let us know by opening an issue.
336+
memory_limit: Optional[int] = None
337+
memory limit in mb. If ``None``, no memory limit is enforced.
338+
random_state: Optional[int | RandomState] = None
339+
A random state used for the ensemble selection process.
339340
340341
Returns
341342
-------
@@ -376,15 +377,15 @@ def __init__(
376377
task_type: int,
377378
metric: Scorer,
378379
ensemble_size: int = 10,
379-
ensemble_nbest: int = 100,
380+
ensemble_nbest: Union[int, float] = 100,
380381
max_models_on_disc: int = 100,
381382
performance_range_threshold: float = 0,
382383
seed: int = 1,
383384
precision: int = 32,
384385
memory_limit: Optional[int] = 1024,
385386
read_at_most: int = 5,
386-
random_state: Optional[Union[int, np.random.RandomState]] = None,
387387
logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
388+
random_state: Optional[Union[int, np.random.RandomState]] = None,
388389
unit_test: bool = False,
389390
):
390391
"""
@@ -400,14 +401,14 @@ def __init__(
400401
type of ML task
401402
metric: str
402403
name of metric to compute the loss of the given predictions
403-
ensemble_size: int
404+
ensemble_size: int = 10
404405
maximal size of ensemble (passed to autosklearn.ensemble.ensemble_selection)
405-
ensemble_nbest: int/float
406+
ensemble_nbest: int | float = 100
406407
if int: consider only the n best prediction
407408
if float: consider only this fraction of the best models
408-
Both wrt to validation predictions
409+
Both with respect to the validation predictions
409410
If performance_range_threshold > 0, might return less models
410-
max_models_on_disc: int
411+
max_models_on_disc: int = 100
411412
Defines the maximum number of models that are kept in the disc.
412413
If int, it must be greater or equal than 1, and dictates the max number of
413414
models to keep.
@@ -417,23 +418,25 @@ def __init__(
417418
Models and predictions of the worst-performing models will be deleted then.
418419
If None, the feature is disabled.
419420
It defines an upper bound on the models that can be used in the ensemble.
420-
performance_range_threshold: float
421+
performance_range_threshold: float = 0
421422
Keep only models that are better than:
422423
dummy + (best - dummy)*performance_range_threshold
423424
E.g dummy=2, best=4, thresh=0.5 --> only consider models with loss > 3
424425
Will at most return the minimum between ensemble_nbest models,
425426
and max_models_on_disc. Might return less
426-
seed: int
427-
random seed
428-
precision: [16,32,64,128]
427+
seed: int = 1
428+
random seed that is used as part of the filename
429+
precision: int in [16,32,64,128] = 32
429430
precision of floats to read the predictions
430-
memory_limit: Optional[int]
431+
memory_limit: Optional[int] = 1024
431432
memory limit in mb. If ``None``, no memory limit is enforced.
432-
read_at_most: int
433+
read_at_most: int = 5
433434
read at most n new prediction files in each iteration
434-
logger_port: int
435+
logger_port: int = DEFAULT_TCP_LOGGING_PORT
435436
port that receives logging records
436-
unit_test: bool
437+
random_state: Optional[int | RandomState] = None
438+
An int or RandomState object used for generating the ensemble.
439+
unit_test: bool = False
437440
Turn on unit testing mode. This currently makes fit_ensemble raise a MemoryError.
438441
Having this is very bad coding style, but I did not find a way to make
439442
unittest.mock work through the pynisher with all spawn contexts. If you know a
@@ -475,7 +478,7 @@ def __init__(
475478
self.precision = precision
476479
self.memory_limit = memory_limit
477480
self.read_at_most = read_at_most
478-
self.random_state = check_random_state(random_state)
481+
self.random_state = random_state
479482
self.unit_test = unit_test
480483

481484
# Setup the logger

autosklearn/ensembles/ensemble_selection.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import random
22
from collections import Counter
3-
from typing import Any, Dict, List, Tuple, Union, cast
3+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
44

55
import numpy as np
66

7+
from sklearn.utils import check_random_state
8+
79
from autosklearn.constants import TASK_TYPES
810
from autosklearn.ensembles.abstract_ensemble import AbstractEnsemble
911
from autosklearn.metrics import Scorer, calculate_loss
@@ -16,15 +18,46 @@ def __init__(
1618
ensemble_size: int,
1719
task_type: int,
1820
metric: Scorer,
19-
random_state: np.random.RandomState,
2021
bagging: bool = False,
2122
mode: str = 'fast',
23+
random_state: Optional[Union[int, np.random.RandomState]] = None,
2224
) -> None:
25+
""" An ensemble of selected algorithms
26+
27+
Fitting an EnsembleSelection generates an ensemble from the the models
28+
generated during the search process. Can be further used for prediction.
29+
30+
Parameters
31+
----------
32+
task_type: int
33+
An identifier indicating which task is being performed.
34+
metric: Scorer
35+
The metric used to evaluate the models
36+
bagging: bool = False
37+
Whether to use bagging in ensemble selection
38+
mode: str in ['fast', 'slow'] = 'fast'
39+
Which kind of ensemble generation to use
40+
* 'slow' - The original method used in Rich Caruana's ensemble selection.
41+
* 'fast' - A faster version of Rich Caruanas' ensemble selection.
42+
43+
random_state: Optional[int | RandomState] = None
44+
The random_state used for ensemble selection.
45+
* None - Uses numpy's default RandomState object
46+
* int - Successive calls to fit will produce the same results
47+
* RandomState - Truely random, each call to fit will produce
48+
different results, even with the same object.
49+
"""
2350
self.ensemble_size = ensemble_size
2451
self.task_type = task_type
2552
self.metric = metric
2653
self.bagging = bagging
2754
self.mode = mode
55+
56+
# Behaviour similar to sklearn
57+
# int - Deteriministic with succesive calls to fit
58+
# RandomState - Successive calls to fit will produce differences
59+
# None - Uses numpmys global singleton RandomState
60+
# https://scikit-learn.org/stable/common_pitfalls.html#controlling-randomness
2861
self.random_state = random_state
2962

3063
def __getstate__(self) -> Dict[str, Any]:
@@ -84,6 +117,7 @@ def _fast(
84117
) -> None:
85118
"""Fast version of Rich Caruana's ensemble selection method."""
86119
self.num_input_models_ = len(predictions)
120+
rand = check_random_state(self.random_state)
87121

88122
ensemble = [] # type: List[np.ndarray]
89123
trajectory = []
@@ -143,7 +177,9 @@ def _fast(
143177
)
144178

145179
all_best = np.argwhere(losses == np.nanmin(losses)).flatten()
146-
best = self.random_state.choice(all_best)
180+
181+
best = rand.choice(all_best)
182+
147183
ensemble.append(predictions[best])
148184
trajectory.append(losses[best])
149185
order.append(best)

autosklearn/evaluation/abstract_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class MyDummyClassifier(DummyClassifier):
4646
def __init__(
4747
self,
4848
config: Configuration,
49-
random_state: np.random.RandomState,
49+
random_state: Optional[Union[int, np.random.RandomState]],
5050
init_params: Optional[Dict[str, Any]] = None,
5151
dataset_properties: Dict[str, Any] = {},
5252
include: Optional[List[str]] = None,
@@ -102,7 +102,7 @@ class MyDummyRegressor(DummyRegressor):
102102
def __init__(
103103
self,
104104
config: Configuration,
105-
random_state: np.random.RandomState,
105+
random_state: Optional[Union[int, np.random.RandomState]],
106106
init_params: Optional[Dict[str, Any]] = None,
107107
dataset_properties: Dict[str, Any] = {},
108108
include: Optional[List[str]] = None,

autosklearn/metalearning/metafeatures/metafeatures.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1106,9 +1106,11 @@ def calculate_all_metafeatures(X, y, categorical, dataset_name, logger,
11061106
X_transformed = check_array(X_transformed,
11071107
force_all_finite=True,
11081108
accept_sparse='csr')
1109-
rs = np.random.RandomState(42)
11101109
indices = np.arange(X_transformed.shape[0])
1110+
1111+
rs = np.random.RandomState(42)
11111112
rs.shuffle(indices)
1113+
11121114
# TODO Shuffle inplace
11131115
X_transformed = X_transformed[indices]
11141116
y_transformed = y[indices]

autosklearn/pipeline/base.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import scipy.sparse
99

1010
from sklearn.pipeline import Pipeline
11-
from sklearn.utils.validation import check_random_state
1211

1312
from .components.base import AutoSklearnChoice, AutoSklearnComponent
1413
import autosklearn.pipeline.create_searchspace_util
@@ -43,6 +42,7 @@ def __init__(self, config=None, steps=None, dataset_properties=None,
4342
self.exclude = exclude if exclude is not None else {}
4443
self.dataset_properties = dataset_properties if \
4544
dataset_properties is not None else {}
45+
self.random_state = random_state
4646

4747
if steps is None:
4848
self.steps = self._get_pipeline_steps(dataset_properties=dataset_properties)
@@ -73,10 +73,6 @@ def __init__(self, config=None, steps=None, dataset_properties=None,
7373

7474
self.set_hyperparameters(self.config, init_params=init_params)
7575

76-
if random_state is None:
77-
self.random_state = check_random_state(1)
78-
else:
79-
self.random_state = check_random_state(random_state)
8076
super().__init__(steps=self.steps)
8177

8278
self._additional_run_info = {}

0 commit comments

Comments
 (0)