Skip to content

Commit 81d5ad7

Browse files
committed
Add numpy nan comparison checker
1 parent 6a32876 commit 81d5ad7

File tree

15 files changed

+111
-12
lines changed

15 files changed

+111
-12
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for numpy nan comparison."""
6+
7+
from __future__ import annotations
8+
9+
from astroid import nodes
10+
from pylint.checkers import BaseChecker
11+
from pylint.checkers.utils import only_required_for_messages
12+
from pylint.interfaces import HIGH
13+
14+
COMPARISON_OP = frozenset(("<", "<=", ">", ">=", "!=", "=="))
15+
NUMPY_NAN = frozenset(("nan", "NaN", "NAN"))
16+
17+
18+
class NumpyNaNComparisonChecker(BaseChecker):
19+
name = "numpy-nan-compare"
20+
msgs = {
21+
"W8001": (
22+
"Numpy nan comparison used",
23+
"numpy-nan-compare",
24+
"Since comparing NaN with NaN always returns False, use np.isnan() to check for NaN values.",
25+
),
26+
}
27+
28+
@classmethod
29+
def __is_np_nan_call(cls, node):
30+
"""Check if the node represents a call to np.nan."""
31+
return (
32+
isinstance(node, nodes.Call)
33+
and isinstance(node.func, nodes.Attribute)
34+
and node.func.attrname in NUMPY_NAN
35+
and isinstance(node.func.expr, nodes.Name)
36+
and node.func.expr.name == "np"
37+
)
38+
39+
@only_required_for_messages("numpy-nan-compare")
40+
def visit_compare(self, node: nodes.Compare) -> None:
41+
# Why am I getting node: node.Call here? Should be nodes.Compare...
42+
# Check test case test_numpy/test_numpy_nan_comparison.py
43+
if self.__is_np_nan_call(node.left):
44+
self.add_message("numpy-nan-compare", node=node, confidence=HIGH)
45+
return
46+
47+
for op, comparator in node.ops:
48+
if op in COMPARISON_OP and self.__is_np_nan_call(comparator):
49+
self.add_message("numpy-nan-compare", node=node, confidence=HIGH)
50+
return

pylint_ml/plugin.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22

33
from pylint.lint import PyLinter
44

5-
from pylint_ml.checkers.numpy.import_numpy import NumpyImportChecker
6-
from pylint_ml.checkers.pandas.import_pandas import PandasImportChecker
7-
from pylint_ml.checkers.scipy.import_scipy import ScipyImportChecker
8-
from pylint_ml.checkers.sklearn.import_sklearn import SklearnImportChecker
9-
from pylint_ml.checkers.tensorflow.import_tensorflow import TensorflowImportChecker
10-
from pylint_ml.checkers.torch.import_torch import TorchImportChecker
5+
from pylint_ml.checkers.numpy.numpy_import import NumpyImportChecker
6+
from pylint_ml.checkers.numpy.numpy_nan_comparison import NumpyNaNComparisonChecker
7+
from pylint_ml.checkers.pandas.pandas_import import PandasImportChecker
8+
from pylint_ml.checkers.scipy.scipy_import import ScipyImportChecker
9+
from pylint_ml.checkers.sklearn.sklearn_import import SklearnImportChecker
10+
from pylint_ml.checkers.tensorflow.tensorflow_import import TensorflowImportChecker
11+
from pylint_ml.checkers.torch.torch_import import TorchImportChecker
1112

1213

1314
def register(linter: PyLinter) -> None:
1415
"""Register checkers."""
1516
# Numpy
1617
linter.register_checker(NumpyImportChecker(linter))
18+
linter.register_checker(NumpyNaNComparisonChecker(linter))
1719

1820
# Pandas
1921
linter.register_checker(PandasImportChecker(linter))

tests/checkers/test_numpy/test_import_numpy.py renamed to tests/checkers/test_numpy/test_numpy_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pylint.testutils
33
from pylint.interfaces import HIGH
44

5-
from pylint_ml.checkers.numpy.import_numpy import NumpyImportChecker
5+
from pylint_ml.checkers.numpy.numpy_import import NumpyImportChecker
66

77

88
class TestNumpyImport(pylint.testutils.CheckerTestCase):
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import astroid
2+
import pylint.testutils
3+
from pylint.interfaces import HIGH
4+
5+
from pylint_ml.checkers.numpy.numpy_nan_comparison import NumpyNaNComparisonChecker
6+
7+
8+
class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase):
9+
CHECKER_CLASS = NumpyNaNComparisonChecker
10+
11+
def test_correct_nan_compare(self):
12+
nan_compare_node = astroid.extract_node(
13+
"""
14+
np.isnan(np.nan)
15+
"""
16+
)
17+
18+
with self.assertNoMessages():
19+
self.checker.visit_compare(nan_compare_node)
20+
21+
def test_incorrect_nan_compare(self):
22+
nan_compare_node = astroid.extract_node(
23+
"""
24+
a_nan = np.array([0, 1, np.nan])
25+
print(a_nan)
26+
# [ 0. 1. nan]
27+
28+
print(a_nan == np.nan)
29+
# [False False False]
30+
31+
print(np.isnan(a_nan))
32+
# [False False True]
33+
34+
print(a_nan > 0)
35+
# [False True False]
36+
"""
37+
)
38+
39+
with self.assertAddsMessages(
40+
pylint.testutils.MessageTest(
41+
msg_id="numpy-nan-compare",
42+
confidence=HIGH,
43+
node=nan_compare_node,
44+
),
45+
ignore_position=True,
46+
):
47+
self.checker.visit_compare(nan_compare_node)

0 commit comments

Comments
 (0)