Skip to content

Commit b05c639

Browse files
authored
Merge pull request #7 from pylint-dev/6-add-torch-import-checker
Add torch import checker
2 parents 2ac07fe + ceaabb3 commit b05c639

File tree

5 files changed

+107
-0
lines changed

5 files changed

+107
-0
lines changed

pylint_ml/checkers/torch/__init__.py

Whitespace-only changes.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for import of torch library."""
6+
7+
from __future__ import annotations
8+
9+
from astroid import nodes
10+
from pylint.checkers import BaseChecker
11+
from pylint.checkers.utils import only_required_for_messages
12+
from pylint.interfaces import HIGH
13+
14+
15+
class TorchImportChecker(BaseChecker):
16+
name = "torch-import"
17+
msgs = {
18+
"W8401": (
19+
"Torch imported with alias",
20+
"torch-import",
21+
"Torch should be imported without an alias to maintain consistency with common practices. "
22+
"Importing Torch with an alias can lead to confusion. "
23+
"Consider using `import torch` for clarity and adherence to the convention.",
24+
),
25+
"W8402": (
26+
"Direct import from Torch discouraged",
27+
"torch-importfrom",
28+
"Direct imports from Torch using `from torch import ...` are discouraged to maintain code "
29+
"clarity and prevent potential conflicts. Using any alias or direct import method can lead to confusion. "
30+
"Consider using `import torch` to adhere to the convention and ensure consistency.",
31+
),
32+
}
33+
34+
@only_required_for_messages("torch-import")
35+
def visit_import(self, node: nodes.Import) -> None:
36+
for name, alias in node.names:
37+
if name == "torch" and alias: # Alias is used
38+
self.add_message("torch-import", node=node, confidence=HIGH)
39+
40+
@only_required_for_messages("torch-importfrom")
41+
def visit_importfrom(self, node: nodes.ImportFrom) -> None:
42+
if node.modname == "torch":
43+
self.add_message("torch-importfrom", node=node, confidence=HIGH)

pylint_ml/plugin.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from pylint_ml.checkers.numpy.import_numpy import NumpyImportChecker
66
from pylint_ml.checkers.pandas.import_pandas import PandasImportChecker
7+
from pylint_ml.checkers.tensorflow.import_tensorflow import TensorflowImportChecker
8+
from pylint_ml.checkers.torch.import_torch import TorchImportChecker
79

810

911
def register(linter: PyLinter) -> None:
@@ -13,3 +15,12 @@ def register(linter: PyLinter) -> None:
1315

1416
# Pandas
1517
linter.register_checker(PandasImportChecker(linter))
18+
19+
# Tensorflow
20+
linter.register_checker(TensorflowImportChecker(linter))
21+
22+
# Torch
23+
linter.register_checker(TorchImportChecker(linter))
24+
25+
# Sklearn
26+
# Scipy

tests/checkers/test_torch/__init__.py

Whitespace-only changes.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import astroid
2+
import pylint.testutils
3+
from pylint.interfaces import HIGH
4+
5+
from pylint_ml.checkers.torch.import_torch import TorchImportChecker
6+
7+
8+
class TestTorchImport(pylint.testutils.CheckerTestCase):
9+
CHECKER_CLASS = TorchImportChecker
10+
11+
def test_correct_torch_import(self):
12+
torch_import_node = astroid.extract_node(
13+
"""
14+
import torch
15+
"""
16+
)
17+
18+
with self.assertNoMessages():
19+
self.checker.visit_import(torch_import_node)
20+
21+
def test_incorrect_torch_import(self):
22+
torch_import_node = astroid.extract_node(
23+
"""
24+
import torch as th
25+
"""
26+
)
27+
28+
with self.assertAddsMessages(
29+
pylint.testutils.MessageTest(
30+
msg_id="torch-import",
31+
confidence=HIGH,
32+
node=torch_import_node,
33+
),
34+
ignore_position=True,
35+
):
36+
self.checker.visit_import(torch_import_node)
37+
38+
def test_incorrect_torch_import_from(self):
39+
torch_importfrom_node = astroid.extract_node(
40+
"""
41+
from torch import min
42+
"""
43+
)
44+
45+
with self.assertAddsMessages(
46+
pylint.testutils.MessageTest(
47+
msg_id="torch-importfrom",
48+
confidence=HIGH,
49+
node=torch_importfrom_node,
50+
),
51+
ignore_position=True,
52+
):
53+
self.checker.visit_importfrom(torch_importfrom_node)

0 commit comments

Comments
 (0)