File tree Expand file tree Collapse file tree 4 files changed +70
-0
lines changed
tests/checkers/test_sklearn Expand file tree Collapse file tree 4 files changed +70
-0
lines changed Original file line number Diff line number Diff line change
1
+ # Licensed under the MIT: https://mit-license.org/
2
+ # For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3
+ # Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4
+
5
+ """Check for import of sklearn library."""
6
+
7
+ from __future__ import annotations
8
+
9
+ from astroid import nodes
10
+ from pylint .checkers import BaseChecker
11
+ from pylint .checkers .utils import only_required_for_messages
12
+ from pylint .interfaces import HIGH
13
+
14
+
15
+ class SklearnImportChecker (BaseChecker ):
16
+ name = "sklearn-import"
17
+ msgs = {
18
+ "W8401" : (
19
+ "Direct or aliased Sklearn import detected" ,
20
+ "sklearn-import" ,
21
+ "Using `import sklearn` or `import sklearn as ...` is not recommended. For better clarity and consistency, "
22
+ "it is advisable to import specific submodules directly, using the `from sklearn import ...` syntax. "
23
+ "This approach prevents confusion and aligns with common practices by explicitly stating which "
24
+ "components of Sklearn are being used." ,
25
+ ),
26
+ }
27
+
28
+ @only_required_for_messages ("sklearn-import" )
29
+ def visit_import (self , node : nodes .Import ) -> None :
30
+ for name , _ in node .names :
31
+ if name == "sklearn" :
32
+ self .add_message ("sklearn-import" , node = node , confidence = HIGH )
Original file line number Diff line number Diff line change 5
5
from pylint_ml .checkers .numpy .import_numpy import NumpyImportChecker
6
6
from pylint_ml .checkers .pandas .import_pandas import PandasImportChecker
7
7
from pylint_ml .checkers .scipy .import_scipy import ScipyImportChecker
8
+ from pylint_ml .checkers .sklearn .import_sklearn import SklearnImportChecker
8
9
from pylint_ml .checkers .tensorflow .import_tensorflow import TensorflowImportChecker
9
10
from pylint_ml .checkers .torch .import_torch import TorchImportChecker
10
11
@@ -27,6 +28,7 @@ def register(linter: PyLinter) -> None:
27
28
linter .register_checker (ScipyImportChecker (linter ))
28
29
29
30
# Sklearn
31
+ linter .register_checker (SklearnImportChecker (linter ))
30
32
31
33
# Theano
32
34
# Matplotlib
Original file line number Diff line number Diff line change
1
+ import astroid
2
+ import pylint .testutils
3
+ from pylint .interfaces import HIGH
4
+
5
+ from pylint_ml .checkers .sklearn .import_sklearn import SklearnImportChecker
6
+
7
+
8
+ class TestSklearnImport (pylint .testutils .CheckerTestCase ):
9
+ CHECKER_CLASS = SklearnImportChecker
10
+
11
+ def test_correct_sklearn_import (self ):
12
+ sklearn_import_node = astroid .extract_node (
13
+ """
14
+ from sklearn import datasets
15
+ """
16
+ )
17
+
18
+ with self .assertNoMessages ():
19
+ self .checker .visit_import (sklearn_import_node )
20
+
21
+ def test_incorrect_sklearn_import (self ):
22
+ sklearn_import_node = astroid .extract_node (
23
+ """
24
+ import sklearn as skl
25
+ """
26
+ )
27
+
28
+ with self .assertAddsMessages (
29
+ pylint .testutils .MessageTest (
30
+ msg_id = "sklearn-import" ,
31
+ confidence = HIGH ,
32
+ node = sklearn_import_node ,
33
+ ),
34
+ ignore_position = True ,
35
+ ):
36
+ self .checker .visit_import (sklearn_import_node )
You can’t perform that action at this time.
0 commit comments