Skip to content

Commit 24f379c

Browse files
committed
Add nan compare checker final
1 parent 81d5ad7 commit 24f379c

File tree

2 files changed

+26
-36
lines changed

2 files changed

+26
-36
lines changed

pylint_ml/checkers/numpy/numpy_nan_comparison.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,18 @@ class NumpyNaNComparisonChecker(BaseChecker):
2626
}
2727

2828
@classmethod
29-
def __is_np_nan_call(cls, node):
29+
def __is_np_nan_call(cls, node: nodes.Attribute) -> bool:
3030
"""Check if the node represents a call to np.nan."""
31-
return (
32-
isinstance(node, nodes.Call)
33-
and isinstance(node.func, nodes.Attribute)
34-
and node.func.attrname in NUMPY_NAN
35-
and isinstance(node.func.expr, nodes.Name)
36-
and node.func.expr.name == "np"
37-
)
31+
return node.attrname in NUMPY_NAN and isinstance(node.expr, nodes.Name) and node.expr.name == "np"
3832

3933
@only_required_for_messages("numpy-nan-compare")
4034
def visit_compare(self, node: nodes.Compare) -> None:
41-
# Why am I getting node: node.Call here? Should be nodes.Compare...
42-
# Check test case test_numpy/test_numpy_nan_comparison.py
43-
if self.__is_np_nan_call(node.left):
35+
36+
if isinstance(node.left, nodes.Attribute) and self.__is_np_nan_call(node.left):
4437
self.add_message("numpy-nan-compare", node=node, confidence=HIGH)
4538
return
4639

4740
for op, comparator in node.ops:
48-
if op in COMPARISON_OP and self.__is_np_nan_call(comparator):
41+
if op in COMPARISON_OP and isinstance(comparator, nodes.Attribute) and self.__is_np_nan_call(comparator):
4942
self.add_message("numpy-nan-compare", node=node, confidence=HIGH)
5043
return

tests/checkers/test_numpy/test_numpy_nan_comparison.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,37 @@
88
class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase):
99
CHECKER_CLASS = NumpyNaNComparisonChecker
1010

11-
def test_correct_nan_compare(self):
12-
nan_compare_node = astroid.extract_node(
13-
"""
14-
np.isnan(np.nan)
15-
"""
16-
)
17-
18-
with self.assertNoMessages():
19-
self.checker.visit_compare(nan_compare_node)
20-
21-
def test_incorrect_nan_compare(self):
22-
nan_compare_node = astroid.extract_node(
23-
"""
11+
def test_singleton_nan_compare(self):
12+
code = """
2413
a_nan = np.array([0, 1, np.nan])
25-
print(a_nan)
26-
# [ 0. 1. nan]
2714
28-
print(a_nan == np.nan)
29-
# [False False False]
15+
np.nan == a_nan #@
3016
31-
print(np.isnan(a_nan))
32-
# [False False True]
17+
1 == 1 == np.nan #@
18+
19+
1 > 0 > np.nan #@
3320
34-
print(a_nan > 0)
35-
# [False True False]
3621
"""
37-
)
22+
singleton_nan_compare, chained_nan_compare, great_than_nan_compare = astroid.extract_node(code)
3823

3924
with self.assertAddsMessages(
4025
pylint.testutils.MessageTest(
4126
msg_id="numpy-nan-compare",
27+
node=singleton_nan_compare,
28+
confidence=HIGH,
29+
),
30+
pylint.testutils.MessageTest(
31+
msg_id="numpy-nan-compare",
32+
node=chained_nan_compare,
33+
confidence=HIGH,
34+
),
35+
pylint.testutils.MessageTest(
36+
msg_id="numpy-nan-compare",
37+
node=great_than_nan_compare,
4238
confidence=HIGH,
43-
node=nan_compare_node,
4439
),
4540
ignore_position=True,
4641
):
47-
self.checker.visit_compare(nan_compare_node)
42+
self.checker.visit_compare(singleton_nan_compare)
43+
self.checker.visit_compare(chained_nan_compare)
44+
self.checker.visit_compare(great_than_nan_compare)

0 commit comments

Comments
 (0)