Skip to content

Commit 28f2d7d

Browse files
author
Peter Hamfelt
committed
Update test
1 parent e867a66 commit 28f2d7d

26 files changed

+280
-138
lines changed

pylint_ml/checkers/matplotlib/__init__.py

Whitespace-only changes.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of Matplotlib functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers.utils import only_required_for_messages
9+
from pylint.interfaces import HIGH
10+
11+
from pylint_ml.util.library_handler import LibraryHandler
12+
13+
14+
class MatplotlibParameterChecker(LibraryHandler):
15+
name = "matplotlib-parameter"
16+
msgs = {
17+
"W8111": (
18+
"Ensure that required parameters %s are explicitly specified in matplotlib method %s.",
19+
"matplotlib-parameter",
20+
"Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
21+
),
22+
}
23+
24+
# Define required parameters for specific matplotlib classes and methods
25+
REQUIRED_PARAMS = {
26+
# Plotting Functions
27+
'plot': ['x', 'y'], # x and y data points are required for basic line plots
28+
'scatter': ['x', 'y'], # x and y data points are required for scatter plots
29+
'bar': ['x', 'height'], # x positions and heights are required for bar plots
30+
'hist': ['x'], # Data points (x) are required for histogram plots
31+
'pie': ['x'], # x data is required for pie chart slices
32+
'imshow': ['X'], # Input array (X) is required for displaying images
33+
'contour': ['X', 'Y', 'Z'], # X, Y, and Z data points are required for contour plots
34+
'contourf': ['X', 'Y', 'Z'], # X, Y, and Z data points for filled contour plots
35+
'pcolormesh': ['X', 'Y', 'C'], # X, Y grid and C color values are required for pseudo color plot
36+
37+
# Axes Functions
38+
'set_xlabel': ['xlabel'], # xlabel is required for setting the x-axis label
39+
'set_ylabel': ['ylabel'], # ylabel is required for setting the y-axis label
40+
'set_xlim': ['left', 'right'], # Left and right bounds for x-axis limit
41+
'set_ylim': ['bottom', 'top'], # Bottom and top bounds for y-axis limit
42+
43+
# Figures and Subplots
44+
'subplots': ['nrows', 'ncols'], # Number of rows and columns are required for creating a subplot grid
45+
'subplot': ['nrows', 'ncols', 'index'], # Number of rows, columns, and index for specific subplot
46+
47+
# Miscellaneous Functions
48+
'savefig': ['fname'], # Filename or file object is required to save a figure
49+
}
50+
51+
@only_required_for_messages("matplotlib-parameter")
52+
def visit_call(self, node: nodes.Call) -> None:
53+
if not self.is_library_imported('matplotlib') and self.is_library_version_valid(lib_version=):
54+
return
55+
56+
method_name = self._get_full_method_name(node)
57+
if method_name in self.REQUIRED_PARAMS:
58+
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
59+
missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords]
60+
if missing_params:
61+
self.add_message(
62+
"matplotlib-parameter",
63+
node=node,
64+
confidence=HIGH,
65+
args=(", ".join(missing_params), method_name),
66+
)
67+
68+
def _get_full_method_name(self, node: nodes.Call) -> str:
69+
func = node.func
70+
method_chain = []
71+
72+
while isinstance(func, nodes.Attribute):
73+
method_chain.insert(0, func.attrname)
74+
func = func.expr
75+
if isinstance(func, nodes.Name):
76+
method_chain.insert(0, func.name)
77+
78+
return ".".join(method_chain)

pylint_ml/checkers/numpy/numpy_dot.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
from __future__ import annotations
88

99
from astroid import nodes
10-
from pylint.checkers import BaseChecker
1110
from pylint.checkers.utils import only_required_for_messages
1211
from pylint.interfaces import HIGH
1312

13+
from pylint_ml.util.library_handler import LibraryHandler
1414

15-
class NumpyDotChecker(BaseChecker):
15+
16+
class NumpyDotChecker(LibraryHandler):
1617
name = "numpy-dot-checker"
1718
msgs = {
1819
"W8122": (
@@ -23,8 +24,14 @@ class NumpyDotChecker(BaseChecker):
2324
),
2425
}
2526

27+
def visit_import(self, node: nodes.Import):
28+
super().visit_import(node=node)
29+
2630
@only_required_for_messages("numpy-dot-usage")
2731
def visit_call(self, node: nodes.Call) -> None:
32+
if not self.is_library_imported('numpy'):
33+
return
34+
2835
# Check if the function being called is np.dot
2936
if isinstance(node.func, nodes.Attribute):
3037
func_name = node.func.attrname

pylint_ml/util/library_handler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from pylint.checkers import BaseChecker
2+
3+
4+
class LibraryHandler(BaseChecker):
5+
6+
def __init__(self, linter):
7+
super().__init__(linter)
8+
self.imports = {}
9+
10+
def visit_import(self, node):
11+
for name, alias in node.names:
12+
self.imports[alias or name] = name
13+
14+
def visit_importfrom(self, node, ):
15+
# TODO Update method to handle either:
16+
# 1. Check of specific method-name imported?
17+
# 2. Store all method names importfrom libname?
18+
19+
module = node.modname
20+
for name, alias in node.names:
21+
full_name = f"{module}.{name}"
22+
self.imports[alias or name] = full_name
23+
24+
def is_library_imported(self, library_name):
25+
return any(mod.startswith(library_name) for mod in self.imports.values())
26+
27+
def is_library_version_valid(self, lib_version):
28+
# TODO update solution
29+
return

tests/checkers/test_numpy/test_numpy_dot.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,20 @@ class TestNumpyDotChecker(pylint.testutils.CheckerTestCase):
99
CHECKER_CLASS = NumpyDotChecker
1010

1111
def test_warning_for_dot(self):
12-
node = astroid.extract_node(
13-
"""
14-
import numpy as np
15-
a = np.array([1, 2])
16-
b = np.array([3, 4])
17-
result = np.dot(a, b) # [numpy-dot-usage]
18-
"""
19-
)
20-
21-
dot_call = node.value
12+
import_np, node = astroid.extract_node("""
13+
import numpy as np #@
14+
a = np.array([1, 2])
15+
b = np.array([3, 4])
16+
np.dot(a, b) #@
17+
""")
2218

2319
with self.assertAddsMessages(
24-
pylint.testutils.MessageTest(
25-
msg_id="numpy-dot-usage",
26-
confidence=HIGH,
27-
node=dot_call,
28-
),
29-
ignore_position=True,
20+
pylint.testutils.MessageTest(
21+
msg_id="numpy-dot-usage",
22+
node=node,
23+
confidence=HIGH,
24+
),
25+
ignore_position=True
3026
):
31-
self.checker.visit_call(dot_call)
27+
self.checker.visit_import(import_np)
28+
self.checker.visit_call(node)

tests/checkers/test_numpy/test_numpy_import.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,45 +9,45 @@ class TestNumpyImport(pylint.testutils.CheckerTestCase):
99
CHECKER_CLASS = NumpyImportChecker
1010

1111
def test_correct_numpy_import(self):
12-
numpy_import_node = astroid.extract_node(
12+
import_node = astroid.extract_node(
1313
"""
14-
import numpy as np
14+
import numpy as np #@
1515
"""
1616
)
1717

1818
with self.assertNoMessages():
19-
self.checker.visit_import(numpy_import_node)
19+
self.checker.visit_import(import_node)
2020

2121
def test_incorrect_numpy_import(self):
22-
numpy_import_node = astroid.extract_node(
22+
import_node = astroid.extract_node(
2323
"""
24-
import numpy as npy
24+
import numpy as npy #@
2525
"""
2626
)
2727

2828
with self.assertAddsMessages(
2929
pylint.testutils.MessageTest(
3030
msg_id="numpy-import",
3131
confidence=HIGH,
32-
node=numpy_import_node,
32+
node=import_node,
3333
),
3434
ignore_position=True,
3535
):
36-
self.checker.visit_import(numpy_import_node)
36+
self.checker.visit_import(import_node)
3737

3838
def test_incorrect_numpy_import_from(self):
39-
numpy_importfrom_node = astroid.extract_node(
39+
importfrom_node = astroid.extract_node(
4040
"""
41-
from numpy import min
41+
from numpy import min #@
4242
"""
4343
)
4444

4545
with self.assertAddsMessages(
4646
pylint.testutils.MessageTest(
4747
msg_id="numpy-importfrom",
4848
confidence=HIGH,
49-
node=numpy_importfrom_node,
49+
node=importfrom_node,
5050
),
5151
ignore_position=True,
5252
):
53-
self.checker.visit_importfrom(numpy_importfrom_node)
53+
self.checker.visit_importfrom(importfrom_node)

tests/checkers/test_numpy/test_numpy_nan_comparison.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,31 @@ class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase):
99
CHECKER_CLASS = NumpyNaNComparisonChecker
1010

1111
def test_singleton_nan_compare(self):
12-
code = """
12+
singleton_node, chained_node, great_than_node = astroid.extract_node("""
1313
a_nan = np.array([0, 1, np.nan])
14-
1514
np.nan == a_nan #@
16-
1715
1 == 1 == np.nan #@
18-
1916
1 > 0 > np.nan #@
20-
21-
"""
22-
singleton_nan_compare, chained_nan_compare, great_than_nan_compare = astroid.extract_node(code)
17+
""")
2318

2419
with self.assertAddsMessages(
2520
pylint.testutils.MessageTest(
2621
msg_id="numpy-nan-compare",
27-
node=singleton_nan_compare,
22+
node=singleton_node,
2823
confidence=HIGH,
2924
),
3025
pylint.testutils.MessageTest(
3126
msg_id="numpy-nan-compare",
32-
node=chained_nan_compare,
27+
node=chained_node,
3328
confidence=HIGH,
3429
),
3530
pylint.testutils.MessageTest(
3631
msg_id="numpy-nan-compare",
37-
node=great_than_nan_compare,
32+
node=great_than_node,
3833
confidence=HIGH,
3934
),
4035
ignore_position=True,
4136
):
42-
self.checker.visit_compare(singleton_nan_compare)
43-
self.checker.visit_compare(chained_nan_compare)
44-
self.checker.visit_compare(great_than_nan_compare)
37+
self.checker.visit_compare(singleton_node)
38+
self.checker.visit_compare(chained_node)
39+
self.checker.visit_compare(great_than_node)

tests/checkers/test_numpy/test_numpy_parameter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_array_missing_object(self):
1212
node = astroid.extract_node(
1313
"""
1414
import numpy as np
15-
arr = np.array() # [numpy-parameter]
15+
arr = np.array() #@
1616
"""
1717
)
1818

@@ -33,7 +33,7 @@ def test_zeros_without_shape(self):
3333
node = astroid.extract_node(
3434
"""
3535
import numpy as np
36-
arr = np.zeros() # [numpy-parameter]
36+
arr = np.zeros() #@
3737
"""
3838
)
3939

@@ -54,7 +54,7 @@ def test_random_rand_without_shape(self):
5454
node = astroid.extract_node(
5555
"""
5656
import numpy as np
57-
arr = np.random.rand() # [numpy-parameter]
57+
arr = np.random.rand() #@
5858
"""
5959
)
6060

@@ -75,7 +75,7 @@ def test_dot_without_b(self):
7575
node = astroid.extract_node(
7676
"""
7777
import numpy as np
78-
arr = np.dot(a=[1, 2, 3]) # [numpy-parameter]
78+
arr = np.dot(a=[1, 2, 3]) #@
7979
"""
8080
)
8181

@@ -96,7 +96,7 @@ def test_percentile_without_q(self):
9696
node = astroid.extract_node(
9797
"""
9898
import numpy as np
99-
result = np.percentile(a=[1, 2, 3]) # [numpy-parameter]
99+
result = np.percentile(a=[1, 2, 3]) #@
100100
"""
101101
)
102102

tests/checkers/test_pandas/pandas_dataframe_column_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_incorrect_column_selection(self):
1313
"""
1414
import pandas as pd
1515
df_sales = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
16-
value = df_sales.A # [pandas-column-selection]
16+
value = df_sales.A #@
1717
"""
1818
)
1919

tests/checkers/test_pandas/test_pandas_dataframe_bool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_dataframe_bool_usage(self):
1313
"""
1414
import pandas as pd
1515
df_customers = pd.DataFrame(data)
16-
df_customers.bool() # [pandas-dataframe-bool]
16+
df_customers.bool() #@
1717
"""
1818
)
1919
with self.assertAddsMessages(
@@ -31,7 +31,7 @@ def test_no_bool_usage(self):
3131
"""
3232
import pandas as pd
3333
df_customers = pd.DataFrame(data)
34-
df_customers.sum() # This should pass without warnings
34+
df_customers.sum() #@
3535
"""
3636
)
3737
with self.assertNoMessages():

0 commit comments

Comments
 (0)