Skip to content

Commit 92c1612

Browse files
authored
Merge pull request #59 from pylint-dev/58-update-test-environment
Update test
2 parents e867a66 + 8504db1 commit 92c1612

25 files changed

+229
-120
lines changed

pylint_ml/checkers/matplotlib/__init__.py

Whitespace-only changes.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of Matplotlib functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers.utils import only_required_for_messages
9+
from pylint.interfaces import HIGH
10+
11+
from pylint_ml.util.library_handler import LibraryHandler
12+
13+
14+
class MatplotlibParameterChecker(LibraryHandler):
15+
name = "matplotlib-parameter"
16+
msgs = {
17+
"W8111": (
18+
"Ensure that required parameters %s are explicitly specified in matplotlib method %s.",
19+
"matplotlib-parameter",
20+
"Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
21+
),
22+
}
23+
24+
# Define required parameters for specific matplotlib classes and methods
25+
REQUIRED_PARAMS = {
26+
# Plotting Functions
27+
"plot": ["x", "y"], # x and y data points are required for basic line plots
28+
"scatter": ["x", "y"], # x and y data points are required for scatter plots
29+
"bar": ["x", "height"], # x positions and heights are required for bar plots
30+
"hist": ["x"], # Data points (x) are required for histogram plots
31+
"pie": ["x"], # x data is required for pie chart slices
32+
"imshow": ["X"], # Input array (X) is required for displaying images
33+
"contour": ["X", "Y", "Z"], # X, Y, and Z data points are required for contour plots
34+
"contourf": ["X", "Y", "Z"], # X, Y, and Z data points for filled contour plots
35+
"pcolormesh": ["X", "Y", "C"], # X, Y grid and C color values are required for pseudo color plot
36+
# Axes Functions
37+
"set_xlabel": ["xlabel"], # xlabel is required for setting the x-axis label
38+
"set_ylabel": ["ylabel"], # ylabel is required for setting the y-axis label
39+
"set_xlim": ["left", "right"], # Left and right bounds for x-axis limit
40+
"set_ylim": ["bottom", "top"], # Bottom and top bounds for y-axis limit
41+
# Figures and Subplots
42+
"subplots": ["nrows", "ncols"], # Number of rows and columns are required for creating a subplot grid
43+
"subplot": ["nrows", "ncols", "index"], # Number of rows, columns, and index for specific subplot
44+
# Miscellaneous Functions
45+
"savefig": ["fname"], # Filename or file object is required to save a figure
46+
}
47+
48+
@only_required_for_messages("matplotlib-parameter")
49+
def visit_call(self, node: nodes.Call) -> None:
50+
# TODO Update
51+
# if not self.is_library_imported('matplotlib') and self.is_library_version_valid(lib_version=):
52+
# return
53+
54+
method_name = self._get_full_method_name(node)
55+
if method_name in self.REQUIRED_PARAMS:
56+
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
57+
missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords]
58+
if missing_params:
59+
self.add_message(
60+
"matplotlib-parameter",
61+
node=node,
62+
confidence=HIGH,
63+
args=(", ".join(missing_params), method_name),
64+
)
65+
66+
def _get_full_method_name(self, node: nodes.Call) -> str:
67+
func = node.func
68+
method_chain = []
69+
70+
while isinstance(func, nodes.Attribute):
71+
method_chain.insert(0, func.attrname)
72+
func = func.expr
73+
if isinstance(func, nodes.Name):
74+
method_chain.insert(0, func.name)
75+
76+
return ".".join(method_chain)

pylint_ml/checkers/numpy/numpy_dot.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
from __future__ import annotations
88

99
from astroid import nodes
10-
from pylint.checkers import BaseChecker
1110
from pylint.checkers.utils import only_required_for_messages
1211
from pylint.interfaces import HIGH
1312

13+
from pylint_ml.util.library_handler import LibraryHandler
1414

15-
class NumpyDotChecker(BaseChecker):
15+
16+
class NumpyDotChecker(LibraryHandler):
1617
name = "numpy-dot-checker"
1718
msgs = {
1819
"W8122": (
@@ -23,8 +24,14 @@ class NumpyDotChecker(BaseChecker):
2324
),
2425
}
2526

27+
def visit_import(self, node: nodes.Import):
28+
super().visit_import(node=node)
29+
2630
@only_required_for_messages("numpy-dot-usage")
2731
def visit_call(self, node: nodes.Call) -> None:
32+
if not self.is_library_imported("numpy"):
33+
return
34+
2835
# Check if the function being called is np.dot
2936
if isinstance(node.func, nodes.Attribute):
3037
func_name = node.func.attrname

pylint_ml/util/library_handler.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from pylint.checkers import BaseChecker
2+
3+
4+
class LibraryHandler(BaseChecker):
5+
6+
def __init__(self, linter):
7+
super().__init__(linter)
8+
self.imports = {}
9+
10+
def visit_import(self, node):
11+
for name, alias in node.names:
12+
self.imports[alias or name] = name
13+
14+
def visit_importfrom(
15+
self,
16+
node,
17+
):
18+
# TODO Update method to handle either:
19+
# 1. Check of specific method-name imported?
20+
# 2. Store all method names importfrom libname?
21+
22+
module = node.modname
23+
for name, alias in node.names:
24+
full_name = f"{module}.{name}"
25+
self.imports[alias or name] = full_name
26+
27+
def is_library_imported(self, library_name):
28+
return any(mod.startswith(library_name) for mod in self.imports.values())
29+
30+
# def is_library_version_valid(self, lib_version):
31+
# # TODO update solution
32+
# if lib_version is None:
33+
# pass
34+
# return

tests/checkers/test_numpy/test_numpy_dot.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,22 @@ class TestNumpyDotChecker(pylint.testutils.CheckerTestCase):
99
CHECKER_CLASS = NumpyDotChecker
1010

1111
def test_warning_for_dot(self):
12-
node = astroid.extract_node(
13-
"""
14-
import numpy as np
15-
a = np.array([1, 2])
16-
b = np.array([3, 4])
17-
result = np.dot(a, b) # [numpy-dot-usage]
12+
import_np, node = astroid.extract_node(
1813
"""
14+
import numpy as np #@
15+
a = np.array([1, 2])
16+
b = np.array([3, 4])
17+
np.dot(a, b) #@
18+
"""
1919
)
2020

21-
dot_call = node.value
22-
2321
with self.assertAddsMessages(
2422
pylint.testutils.MessageTest(
2523
msg_id="numpy-dot-usage",
24+
node=node,
2625
confidence=HIGH,
27-
node=dot_call,
2826
),
2927
ignore_position=True,
3028
):
31-
self.checker.visit_call(dot_call)
29+
self.checker.visit_import(import_np)
30+
self.checker.visit_call(node)

tests/checkers/test_numpy/test_numpy_import.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,45 +9,45 @@ class TestNumpyImport(pylint.testutils.CheckerTestCase):
99
CHECKER_CLASS = NumpyImportChecker
1010

1111
def test_correct_numpy_import(self):
12-
numpy_import_node = astroid.extract_node(
12+
import_node = astroid.extract_node(
1313
"""
14-
import numpy as np
14+
import numpy as np #@
1515
"""
1616
)
1717

1818
with self.assertNoMessages():
19-
self.checker.visit_import(numpy_import_node)
19+
self.checker.visit_import(import_node)
2020

2121
def test_incorrect_numpy_import(self):
22-
numpy_import_node = astroid.extract_node(
22+
import_node = astroid.extract_node(
2323
"""
24-
import numpy as npy
24+
import numpy as npy #@
2525
"""
2626
)
2727

2828
with self.assertAddsMessages(
2929
pylint.testutils.MessageTest(
3030
msg_id="numpy-import",
3131
confidence=HIGH,
32-
node=numpy_import_node,
32+
node=import_node,
3333
),
3434
ignore_position=True,
3535
):
36-
self.checker.visit_import(numpy_import_node)
36+
self.checker.visit_import(import_node)
3737

3838
def test_incorrect_numpy_import_from(self):
39-
numpy_importfrom_node = astroid.extract_node(
39+
importfrom_node = astroid.extract_node(
4040
"""
41-
from numpy import min
41+
from numpy import min #@
4242
"""
4343
)
4444

4545
with self.assertAddsMessages(
4646
pylint.testutils.MessageTest(
4747
msg_id="numpy-importfrom",
4848
confidence=HIGH,
49-
node=numpy_importfrom_node,
49+
node=importfrom_node,
5050
),
5151
ignore_position=True,
5252
):
53-
self.checker.visit_importfrom(numpy_importfrom_node)
53+
self.checker.visit_importfrom(importfrom_node)

tests/checkers/test_numpy/test_numpy_nan_comparison.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,33 @@ class TestNumpyNaNComparison(pylint.testutils.CheckerTestCase):
99
CHECKER_CLASS = NumpyNaNComparisonChecker
1010

1111
def test_singleton_nan_compare(self):
12-
code = """
12+
singleton_node, chained_node, great_than_node = astroid.extract_node(
13+
"""
1314
a_nan = np.array([0, 1, np.nan])
14-
1515
np.nan == a_nan #@
16-
1716
1 == 1 == np.nan #@
18-
1917
1 > 0 > np.nan #@
20-
2118
"""
22-
singleton_nan_compare, chained_nan_compare, great_than_nan_compare = astroid.extract_node(code)
19+
)
2320

2421
with self.assertAddsMessages(
2522
pylint.testutils.MessageTest(
2623
msg_id="numpy-nan-compare",
27-
node=singleton_nan_compare,
24+
node=singleton_node,
2825
confidence=HIGH,
2926
),
3027
pylint.testutils.MessageTest(
3128
msg_id="numpy-nan-compare",
32-
node=chained_nan_compare,
29+
node=chained_node,
3330
confidence=HIGH,
3431
),
3532
pylint.testutils.MessageTest(
3633
msg_id="numpy-nan-compare",
37-
node=great_than_nan_compare,
34+
node=great_than_node,
3835
confidence=HIGH,
3936
),
4037
ignore_position=True,
4138
):
42-
self.checker.visit_compare(singleton_nan_compare)
43-
self.checker.visit_compare(chained_nan_compare)
44-
self.checker.visit_compare(great_than_nan_compare)
39+
self.checker.visit_compare(singleton_node)
40+
self.checker.visit_compare(chained_node)
41+
self.checker.visit_compare(great_than_node)

tests/checkers/test_numpy/test_numpy_parameter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_array_missing_object(self):
1212
node = astroid.extract_node(
1313
"""
1414
import numpy as np
15-
arr = np.array() # [numpy-parameter]
15+
arr = np.array() #@
1616
"""
1717
)
1818

@@ -33,7 +33,7 @@ def test_zeros_without_shape(self):
3333
node = astroid.extract_node(
3434
"""
3535
import numpy as np
36-
arr = np.zeros() # [numpy-parameter]
36+
arr = np.zeros() #@
3737
"""
3838
)
3939

@@ -54,7 +54,7 @@ def test_random_rand_without_shape(self):
5454
node = astroid.extract_node(
5555
"""
5656
import numpy as np
57-
arr = np.random.rand() # [numpy-parameter]
57+
arr = np.random.rand() #@
5858
"""
5959
)
6060

@@ -75,7 +75,7 @@ def test_dot_without_b(self):
7575
node = astroid.extract_node(
7676
"""
7777
import numpy as np
78-
arr = np.dot(a=[1, 2, 3]) # [numpy-parameter]
78+
arr = np.dot(a=[1, 2, 3]) #@
7979
"""
8080
)
8181

@@ -96,7 +96,7 @@ def test_percentile_without_q(self):
9696
node = astroid.extract_node(
9797
"""
9898
import numpy as np
99-
result = np.percentile(a=[1, 2, 3]) # [numpy-parameter]
99+
result = np.percentile(a=[1, 2, 3]) #@
100100
"""
101101
)
102102

tests/checkers/test_pandas/pandas_dataframe_column_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_incorrect_column_selection(self):
1313
"""
1414
import pandas as pd
1515
df_sales = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
16-
value = df_sales.A # [pandas-column-selection]
16+
value = df_sales.A #@
1717
"""
1818
)
1919

tests/checkers/test_pandas/test_pandas_dataframe_bool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_dataframe_bool_usage(self):
1313
"""
1414
import pandas as pd
1515
df_customers = pd.DataFrame(data)
16-
df_customers.bool() # [pandas-dataframe-bool]
16+
df_customers.bool() #@
1717
"""
1818
)
1919
with self.assertAddsMessages(
@@ -31,7 +31,7 @@ def test_no_bool_usage(self):
3131
"""
3232
import pandas as pd
3333
df_customers = pd.DataFrame(data)
34-
df_customers.sum() # This should pass without warnings
34+
df_customers.sum() #@
3535
"""
3636
)
3737
with self.assertNoMessages():

0 commit comments

Comments
 (0)