Skip to content

Commit a5fd5a8

Browse files
author
Peter Hamfelt
committed
Add numpy dot checker
1 parent d7bfd95 commit a5fd5a8

File tree

3 files changed

+66
-1
lines changed

3 files changed

+66
-1
lines changed

pylint_ml/checkers/numpy/numpy_dot.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for the use of np.dot and recommend np.matmul for matrix multiplication."""
6+
7+
from __future__ import annotations
8+
9+
from astroid import nodes
10+
from pylint.checkers import BaseChecker
11+
from pylint.checkers.utils import only_required_for_messages
12+
from pylint.interfaces import HIGH
13+
14+
15+
class NumpyDotChecker(BaseChecker):
16+
name = "numpy-dot-checker"
17+
msgs = {
18+
"W8122": (
19+
"Consider using 'np.matmul()' instead of 'np.dot()' for matrix multiplication.",
20+
"numpy-dot-usage",
21+
"It's recommended to use 'np.matmul()' for matrix multiplication, which is more explicit and handles "
22+
"higher-dimensional arrays more consistently. ",
23+
),
24+
}
25+
26+
@only_required_for_messages("numpy-dot-usage")
27+
def visit_call(self, node: nodes.Call) -> None:
28+
# Check if the function being called is np.dot
29+
if isinstance(node.func, nodes.Attribute):
30+
func_name = node.func.attrname
31+
module_name = getattr(node.func.expr, "name", None)
32+
33+
if func_name == "dot" and module_name == "np":
34+
# Suggest using np.matmul() instead
35+
self.add_message("numpy-dot-usage", node=node, confidence=HIGH)

pylint_ml/checkers/pandas/pandas_series_naming.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ class PandasSeriesNamingChecker(BaseChecker):
2424

2525
@only_required_for_messages("pandas-series-naming")
2626
def visit_assign(self, node: nodes.Assign) -> None:
27-
print(node)
2827
if isinstance(node.value, nodes.Call):
2928
func_name = getattr(node.value.func, "attrname", None)
3029
module_name = getattr(node.value.func.expr, "name", None)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import astroid
2+
import pylint.testutils
3+
from pylint.interfaces import HIGH
4+
5+
from pylint_ml.checkers.numpy.numpy_dot import NumpyDotChecker
6+
7+
8+
class TestNumpyDotChecker(pylint.testutils.CheckerTestCase):
9+
CHECKER_CLASS = NumpyDotChecker
10+
11+
def test_warning_for_dot(self):
12+
node = astroid.extract_node(
13+
"""
14+
import numpy as np
15+
a = np.array([1, 2])
16+
b = np.array([3, 4])
17+
result = np.dot(a, b) # [numpy-dot-usage]
18+
"""
19+
)
20+
21+
dot_call = node.value
22+
23+
with self.assertAddsMessages(
24+
pylint.testutils.MessageTest(
25+
msg_id="numpy-dot-usage",
26+
confidence=HIGH,
27+
node=dot_call,
28+
),
29+
ignore_position=True,
30+
):
31+
self.checker.visit_call(dot_call)

0 commit comments

Comments
 (0)