Skip to content

Commit 1df441b

Browse files
authored
Merge pull request #13 from pylint-dev/12-add-numpy-nan-comparision-checker
Add numpy nan comparison checker
2 parents ae43c17 + 24f379c commit 1df441b

File tree

15 files changed

+101
-12
lines changed

15 files changed

+101
-12
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for numpy nan comparison."""
6+
7+
from __future__ import annotations
8+
9+
from astroid import nodes
10+
from pylint.checkers import BaseChecker
11+
from pylint.checkers.utils import only_required_for_messages
12+
from pylint.interfaces import HIGH
13+
14+
COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "=="))
15+
NUMPY_NAN = frozenset(("nan", "NaN", "NAN"))
16+
17+
18+
class NumpyNaNComparisonChecker(BaseChecker):
19+
name = "numpy-nan-compare"
20+
msgs = {
21+
"W8001": (
22+
"Numpy nan comparison used",
23+
"numpy-nan-compare",
24+
"Since comparing NaN with NaN always returns False, use np.isnan() to check for NaN values.",
25+
),
26+
}
27+
28+
@classmethod
29+
def __is_np_nan_call(cls, node: nodes.Attribute) -> bool:
30+
"""Check if the node represents a call to np.nan."""
31+
return node.attrname in NUMPY_NAN and isinstance(node.expr, nodes.Name) and node.expr.name == "np"
32+
33+
@only_required_for_messages("numpy-nan-compare")
34+
def visit_compare(self, node: nodes.Compare) -> None:
35+
36+
if isinstance(node.left, nodes.Attribute) and self.__is_np_nan_call(node.left):
37+
self.add_message("numpy-nan-compare", node=node, confidence=HIGH)
38+
return
39+
40+
for op, comparator in node.ops:
41+
if op in COMPARISON_OP and isinstance(comparator, nodes.Attribute) and self.__is_np_nan_call(comparator):
42+
self.add_message("numpy-nan-compare", node=node, confidence=HIGH)
43+
return

pylint_ml/plugin.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22

33
from pylint.lint import PyLinter
44

5-
from pylint_ml.checkers.numpy.import_numpy import NumpyImportChecker
6-
from pylint_ml.checkers.pandas.import_pandas import PandasImportChecker
7-
from pylint_ml.checkers.scipy.import_scipy import ScipyImportChecker
8-
from pylint_ml.checkers.sklearn.import_sklearn import SklearnImportChecker
9-
from pylint_ml.checkers.tensorflow.import_tensorflow import TensorflowImportChecker
10-
from pylint_ml.checkers.torch.import_torch import TorchImportChecker
5+
from pylint_ml.checkers.numpy.numpy_import import NumpyImportChecker
6+
from pylint_ml.checkers.numpy.numpy_nan_comparison import NumpyNaNComparisonChecker
7+
from pylint_ml.checkers.pandas.pandas_import import PandasImportChecker
8+
from pylint_ml.checkers.scipy.scipy_import import ScipyImportChecker
9+
from pylint_ml.checkers.sklearn.sklearn_import import SklearnImportChecker
10+
from pylint_ml.checkers.tensorflow.tensorflow_import import TensorflowImportChecker
11+
from pylint_ml.checkers.torch.torch_import import TorchImportChecker
1112

1213

1314
def register(linter: PyLinter) -> None:
1415
"""Register checkers."""
1516
# Numpy
1617
linter.register_checker(NumpyImportChecker(linter))
18+
linter.register_checker(NumpyNaNComparisonChecker(linter))
1719

1820
# Pandas
1921
linter.register_checker(PandasImportChecker(linter))

tests/checkers/test_numpy/test_import_numpy.py renamed to tests/checkers/test_numpy/test_numpy_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pylint.testutils
33
from pylint.interfaces import HIGH
44

5-
from pylint_ml.checkers.numpy.import_numpy import NumpyImportChecker
5+
from pylint_ml.checkers.numpy.numpy_import import NumpyImportChecker
66

77

88
class TestNumpyImport(pylint.testutils.CheckerTestCase):
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import astroid
2+
import pylint.testutils
3+
from pylint.interfaces import HIGH
4+
5+
from pylint_ml.checkers.numpy.numpy_nan_comparison import NumpyNaNComparisonChecker
6+
7+
8+
class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase):
9+
CHECKER_CLASS = NumpyNaNComparisonChecker
10+
11+
def test_singleton_nan_compare(self):
12+
code = """
13+
a_nan = np.array([0, 1, np.nan])
14+
15+
np.nan == a_nan #@
16+
17+
1 == 1 == np.nan #@
18+
19+
1 > 0 > np.nan #@
20+
21+
"""
22+
singleton_nan_compare, chained_nan_compare, great_than_nan_compare = astroid.extract_node(code)
23+
24+
with self.assertAddsMessages(
25+
pylint.testutils.MessageTest(
26+
msg_id="numpy-nan-compare",
27+
node=singleton_nan_compare,
28+
confidence=HIGH,
29+
),
30+
pylint.testutils.MessageTest(
31+
msg_id="numpy-nan-compare",
32+
node=chained_nan_compare,
33+
confidence=HIGH,
34+
),
35+
pylint.testutils.MessageTest(
36+
msg_id="numpy-nan-compare",
37+
node=great_than_nan_compare,
38+
confidence=HIGH,
39+
),
40+
ignore_position=True,
41+
):
42+
self.checker.visit_compare(singleton_nan_compare)
43+
self.checker.visit_compare(chained_nan_compare)
44+
self.checker.visit_compare(great_than_nan_compare)

0 commit comments

Comments
 (0)