Skip to content

Commit a48b0af

Browse files
authored
Make all references to sklearn.utils.validation.check_is_fitted uniform (#806)
1 parent c16247e commit a48b0af

File tree

9 files changed

+14
-18
lines changed

9 files changed

+14
-18
lines changed

dask_ml/_compat.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import contextlib
22
import os
33
from collections.abc import Mapping # noqa
4-
from typing import Any, List, Optional, Union
4+
from typing import Any
55

66
import dask
77
import dask.array as da
@@ -43,12 +43,6 @@ def dummy_context(*args: Any, **kwargs: Any):
4343
blockwise = da.blockwise
4444

4545

46-
def check_is_fitted(est, attributes: Optional[Union[str, List[str]]] = None):
47-
args: Any = ()
48-
49-
return sklearn.utils.validation.check_is_fitted(est, *args)
50-
51-
5246
def _check_multimetric_scoring(estimator, scoring=None):
5347
# TODO: See if scikit-learn 0.24 solves the need for using
5448
# a private method

dask_ml/cluster/k_means.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from dask import compute
1111
from sklearn.base import BaseEstimator, TransformerMixin
1212
from sklearn.utils.extmath import squared_norm
13+
from sklearn.utils.validation import check_is_fitted
1314

14-
from .._compat import SK_024, blockwise, check_is_fitted
15+
from .._compat import SK_024, blockwise
1516
from .._utils import draw_seed
1617
from ..metrics import (
1718
euclidean_distances,

dask_ml/decomposition/pca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import sklearn.decomposition
88
from dask import compute, delayed
99
from sklearn.utils.extmath import fast_logdet
10-
from sklearn.utils.validation import check_random_state
10+
from sklearn.utils.validation import check_is_fitted, check_random_state
1111

12-
from .._compat import DASK_2_26_0, check_is_fitted
12+
from .._compat import DASK_2_26_0
1313
from .._utils import draw_seed
1414
from ..utils import svd_flip
1515

dask_ml/model_selection/_incremental.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
from sklearn.model_selection import ParameterGrid, ParameterSampler
2525
from sklearn.utils import check_random_state
2626
from sklearn.utils.metaestimators import if_delegate_has_method
27+
from sklearn.utils.validation import check_is_fitted
2728

28-
from .._compat import DISTRIBUTED_2021_02_0, annotate, check_is_fitted, dummy_context
29+
from .._compat import DISTRIBUTED_2021_02_0, annotate, dummy_context
2930
from .._typing import ArrayLike, Int
3031
from .._utils import LoggingContext
3132
from ..wrappers import ParallelPostFit

dask_ml/model_selection/_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@
3333
from sklearn.pipeline import FeatureUnion, Pipeline
3434
from sklearn.utils.metaestimators import if_delegate_has_method
3535
from sklearn.utils.multiclass import type_of_target
36-
from sklearn.utils.validation import _num_samples
36+
from sklearn.utils.validation import _num_samples, check_is_fitted
3737

38-
from .._compat import SK_024, SK_VERSION, check_is_fitted
38+
from .._compat import SK_024, SK_VERSION
3939
from ._normalize import normalize_estimator
4040
from .methods import (
4141
MISSING,

dask_ml/model_selection/_successive_halving.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import numpy as np
55
import toolz
6+
from sklearn.utils.validation import check_is_fitted
67

7-
from .._compat import check_is_fitted
88
from ._incremental import IncrementalSearchCV
99

1010

dask_ml/preprocessing/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from pandas.api.types import is_categorical_dtype
1616
from scipy import stats
1717
from sklearn.base import BaseEstimator, TransformerMixin
18-
from sklearn.utils.validation import check_random_state
18+
from sklearn.utils.validation import check_is_fitted, check_random_state
1919

20-
from dask_ml._compat import blockwise, check_is_fitted
20+
from dask_ml._compat import blockwise
2121
from dask_ml._utils import copy_learned_attributes
2222
from dask_ml.utils import check_array, handle_zeros_in_scale
2323

dask_ml/preprocessing/label.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import pandas as pd
1010
import scipy.sparse
1111
import sklearn.preprocessing
12+
from sklearn.utils.validation import check_is_fitted
1213

13-
from .._compat import check_is_fitted
1414
from .._typing import ArrayLike, SeriesType
1515

1616

dask_ml/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import numpy as np
88
import sklearn.base
99
import sklearn.metrics
10+
from sklearn.utils.validation import check_is_fitted
1011

1112
from dask_ml.utils import _timer
1213

13-
from ._compat import check_is_fitted
1414
from ._partial import fit
1515
from ._utils import copy_learned_attributes
1616
from .metrics import check_scoring, get_scorer

0 commit comments

Comments
 (0)