Skip to content

Commit 70f8367

Browse files
committed
Add sklearn import checker
1 parent 0f82c82 commit 70f8367

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for import of sklearn library."""
6+
7+
from __future__ import annotations
8+
9+
from astroid import nodes
10+
from pylint.checkers import BaseChecker
11+
from pylint.checkers.utils import only_required_for_messages
12+
from pylint.interfaces import HIGH
13+
14+
15+
class SklearnImportChecker(BaseChecker):
16+
name = "sklearn-import"
17+
msgs = {
18+
"W8401": (
19+
"Direct or aliased Sklearn import detected",
20+
"sklearn-import",
21+
"Using `import sklearn` or `import sklearn as ...` is not recommended. For better clarity and consistency, "
22+
"it is advisable to import specific submodules directly, using the `from sklearn import ...` syntax. "
23+
"This approach prevents confusion and aligns with common practices by explicitly stating which "
24+
"components of Sklearn are being used.",
25+
),
26+
}
27+
28+
@only_required_for_messages("sklearn-import")
29+
def visit_import(self, node: nodes.Import) -> None:
30+
for name, _ in node.names:
31+
if name == "sklearn":
32+
self.add_message("sklearn-import", node=node, confidence=HIGH)

pylint_ml/plugin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pylint_ml.checkers.numpy.import_numpy import NumpyImportChecker
66
from pylint_ml.checkers.pandas.import_pandas import PandasImportChecker
77
from pylint_ml.checkers.scipy.import_scipy import ScipyImportChecker
8+
from pylint_ml.checkers.sklearn.import_sklearn import SklearnImportChecker
89
from pylint_ml.checkers.tensorflow.import_tensorflow import TensorflowImportChecker
910
from pylint_ml.checkers.torch.import_torch import TorchImportChecker
1011

@@ -27,6 +28,7 @@ def register(linter: PyLinter) -> None:
2728
linter.register_checker(ScipyImportChecker(linter))
2829

2930
# Sklearn
31+
linter.register_checker(SklearnImportChecker(linter))
3032

3133
# Theano
3234
# Matplotlib

tests/checkers/test_sklearn/__init__.py

Whitespace-only changes.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import astroid
2+
import pylint.testutils
3+
from pylint.interfaces import HIGH
4+
5+
from pylint_ml.checkers.sklearn.import_sklearn import SklearnImportChecker
6+
7+
8+
class TestSklearnImport(pylint.testutils.CheckerTestCase):
9+
CHECKER_CLASS = SklearnImportChecker
10+
11+
def test_correct_sklearn_import(self):
12+
sklearn_import_node = astroid.extract_node(
13+
"""
14+
from sklearn import datasets
15+
"""
16+
)
17+
18+
with self.assertNoMessages():
19+
self.checker.visit_import(sklearn_import_node)
20+
21+
def test_incorrect_sklearn_import(self):
22+
sklearn_import_node = astroid.extract_node(
23+
"""
24+
import sklearn as skl
25+
"""
26+
)
27+
28+
with self.assertAddsMessages(
29+
pylint.testutils.MessageTest(
30+
msg_id="sklearn-import",
31+
confidence=HIGH,
32+
node=sklearn_import_node,
33+
),
34+
ignore_position=True,
35+
):
36+
self.checker.visit_import(sklearn_import_node)

0 commit comments

Comments
 (0)