Skip to content

Commit 858dbf3

Browse files
authored
Merge pull request #55 from pylint-dev/54-add-library-parameter-checker
add lib parameter checkers
2 parents b8d5cdd + 5e0a201 commit 858dbf3

File tree

11 files changed

+729
-194
lines changed

11 files changed

+729
-194
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of numpy functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class NumPyParameterChecker(BaseChecker):
    """Warn when selected ``np.<function>`` calls omit listed parameters as keywords.

    Only keyword arguments are inspected: a parameter passed positionally is
    still reported as missing. Only calls whose attribute chain is rooted at the
    literal name ``np`` are considered.
    """

    name = "numpy-parameter"
    # NOTE(review): the id "W8111" is also declared by the pandas and scipy
    # parameter checkers added in the same commit -- confirm that message ids
    # are meant to be unique per symbol before registering all three together.
    msgs = {
        "W8111": (
            "Ensure that required parameters %s are explicitly specified in numpy method %s.",
            "numpy-parameter",
            "Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
        ),
    }

    # Maps the dotted function name *without* the leading ``np.`` to the
    # keyword names that must appear in the call. An empty list means the
    # function is recognised but nothing is enforced for it.
    REQUIRED_PARAMS = {
        # Array Creation
        "array": ["object"],
        "zeros": ["shape"],
        "ones": ["shape"],
        "full": ["shape", "fill_value"],
        "empty": ["shape"],
        "arange": ["start"],
        "linspace": ["start", "stop"],
        "logspace": ["start", "stop"],
        "eye": ["N"],
        "identity": ["n"],
        # Random Sampling
        "random.rand": ["d0"],
        "random.randn": ["d0"],
        "random.randint": ["low", "high"],
        "random.choice": ["a"],
        "random.uniform": ["low", "high"],
        "random.normal": ["loc", "scale"],
        # Mathematical Functions
        "sum": ["a"],
        "mean": ["a"],
        "median": ["a"],
        "std": ["a"],
        "var": ["a"],
        "prod": ["a"],
        "min": ["a"],
        "max": ["a"],
        "ptp": ["a"],
        # Array Manipulation
        "reshape": ["newshape"],
        "transpose": [],
        "concatenate": ["arrays"],
        "stack": ["arrays"],
        "vstack": ["arrays"],
        "hstack": ["arrays"],
        # Linear Algebra
        "dot": ["a", "b"],
        "matmul": ["a", "b"],
        "linalg.inv": ["a"],
        "linalg.eig": ["a"],
        "linalg.solve": ["a", "b"],
        # Statistical Functions
        "percentile": ["a", "q"],
        "quantile": ["a", "q"],
        "corrcoef": ["x"],
        "cov": ["m"],
    }

    @only_required_for_messages("numpy-parameter")
    def visit_call(self, node: nodes.Call) -> None:
        """Emit ``numpy-parameter`` when a known numpy call lacks listed keywords.

        Args:
            node: The call expression being visited.
        """
        qualified_name = self._get_full_method_name(node)
        required = self.REQUIRED_PARAMS.get(qualified_name)
        if not required:
            # Either not a tracked ``np.`` call, or nothing is enforced for it.
            return

        # ``**kwargs`` expansions have ``arg`` set to None and are ignored.
        passed_by_keyword = {keyword.arg for keyword in node.keywords if keyword.arg is not None}
        absent = [param for param in required if param not in passed_by_keyword]
        if not absent:
            return

        self.add_message(
            "numpy-parameter",
            node=node,
            confidence=HIGH,
            args=(", ".join(absent), qualified_name),
        )

    @staticmethod
    def _get_full_method_name(node: nodes.Call) -> str:
        """Return the dotted name after ``np.`` for the called function.

        For ``np.random.rand(...)`` this yields ``"random.rand"``.

        Args:
            node: The call expression to inspect.

        Returns:
            The attribute chain joined with dots, or an empty string when the
            chain is not rooted at a name spelled exactly ``np``.
        """
        current = node.func
        innermost_last = []

        # Walking from the call outwards visits the last attribute first, so
        # collect in that order and flip once at the end.
        while isinstance(current, nodes.Attribute):
            innermost_last.append(current.attrname)
            current = current.expr

        if not isinstance(current, nodes.Name) or current.name != "np":
            return ""
        return ".".join(reversed(innermost_last))

pylint_ml/checkers/pandas/pandas_dataframe_merge.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

pylint_ml/checkers/pandas/pandas_dtype_param.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of Pandas functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class PandasParameterChecker(BaseChecker):
    """Warn when selected pandas-style calls omit listed parameters as keywords.

    Calls are matched purely by the called name (see ``_get_method_name``); the
    receiver is not inspected, so any call whose final attribute or bare name
    equals a key of ``REQUIRED_PARAMS`` is checked. Only keyword arguments are
    considered: a parameter passed positionally is still reported as missing.
    """

    name = "pandas-parameter"
    # NOTE(review): the id "W8111" is also declared by the numpy and scipy
    # parameter checkers added in the same commit -- confirm that message ids
    # are meant to be unique per symbol before registering all three together.
    msgs = {
        "W8111": (
            "Ensure that required parameters %s are explicitly specified in Pandas method %s.",
            "pandas-parameter",
            "Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
        ),
    }

    # Maps a called name to the keyword names that must appear in the call.
    # An empty list means the name is recognised but nothing is enforced.
    REQUIRED_PARAMS = {
        # DataFrame creation
        "DataFrame": ["data"],  # The primary input data for DataFrame creation
        # Concatenation
        "concat": ["objs"],  # The list or dictionary of DataFrames/Series to concatenate
        # DataFrame I/O (Input/Output)
        "read_csv": ["filepath_or_buffer", "dtype"],  # Path to CSV file or file-like object; column data types
        "read_excel": ["io", "dtype"],  # Path to Excel file or file-like object; column data types
        "read_table": ["filepath_or_buffer", "dtype"],  # Path to delimited text file or file object; column dtypes
        "to_csv": ["path_or_buf"],  # File path or buffer to write the DataFrame to
        "to_excel": ["excel_writer"],  # File path or ExcelWriter object to write the data to
        # Merging and Joining
        "merge": ["right", "how", "on", "validate"],  # Object to merge with, join type, join key(s), validation
        "join": ["other"],  # The DataFrame or Series to join
        # DataFrame Operations
        "pivot_table": ["index"],  # The column to pivot on (values and columns have defaults)
        "groupby": ["by"],  # The key or list of keys to group by
        "resample": ["rule"],  # The frequency rule to resample by
        # Data Cleaning and Transformation
        "fillna": ["value"],  # Value to use to fill NA/NaN values
        "drop": ["labels"],  # Labels to drop
        "drop_duplicates": ["subset"],  # Subset of columns to consider when dropping duplicates
        "replace": ["to_replace"],  # Values to replace
        # Plotting
        "plot": ["x"],  # x-values or index for plotting
        "hist": ["column"],  # Column to plot the histogram for
        "boxplot": ["column"],  # Column(s) to plot boxplot for
        # DataFrame Sorting
        "sort_values": ["by"],  # Column(s) to sort by
        "sort_index": ["axis"],  # Axis to sort along (index=0, columns=1)
        # Statistical Functions
        "corr": ["method"],  # Method to use for correlation ('pearson', 'kendall', 'spearman')
        "describe": [],  # No required parameters, but additional ones could be specified
        # Windowing/Resampling Functions
        "rolling": ["window"],  # Size of the moving window
        "ewm": ["span"],  # Span for exponentially weighted calculations
        # Miscellaneous Functions
        "apply": ["func"],  # Function to apply to the data
        "agg": ["func"],  # Function or list of functions for aggregation
    }

    @only_required_for_messages("pandas-parameter")
    def visit_call(self, node: nodes.Call) -> None:
        """Emit ``pandas-parameter`` when a tracked call lacks listed keywords.

        Args:
            node: The call expression being visited.
        """
        method_name = self._get_method_name(node)
        if method_name in self.REQUIRED_PARAMS:
            # ``**kwargs`` expansions have ``arg`` set to None and are ignored.
            provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
            # Collect all missing parameters so one message lists them together
            missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords]
            if missing_params:
                self.add_message(
                    "pandas-parameter",
                    node=node,
                    confidence=HIGH,
                    args=(", ".join(missing_params), method_name),
                )

    @staticmethod
    def _get_method_name(node: nodes.Call) -> str:
        """Return the short name of the function or method being called.

        For an attribute call such as ``df.merge(...)`` or ``pd.merge(...)``
        only the final attribute (``"merge"``) is returned; the rest of the
        chain is ignored. For a bare call such as ``merge(...)`` the name
        itself is returned.

        Args:
            node: The call expression to inspect.

        Returns:
            The called name, or an empty string when the callee is neither an
            attribute access nor a plain name (for example a subscript or the
            result of another call).
        """
        func = node.func
        if isinstance(func, nodes.Attribute):
            return func.attrname
        if isinstance(func, nodes.Name):
            return func.name
        return ""

pylint_ml/checkers/scipy/scipy_import.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ class ScipyImportChecker(BaseChecker):
2525
),
2626
}
2727

28-
@only_required_for_messages("scipy-import")
28+
@only_required_for_messages("scipy-import", "scipy-wildcard-import")
2929
def visit_import(self, node: nodes.Import) -> None:
30-
for name, _ in node.names:
30+
for name, _alias in node.names:
3131
if name == "scipy":
32+
# Flag direct or aliased imports of scipy
3233
self.add_message("scipy-import", node=node, confidence=HIGH)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of Scipy functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class ScipyParameterChecker(BaseChecker):
    """Warn when selected scipy calls omit listed parameters as keywords.

    Only keyword arguments are inspected: a parameter passed positionally is
    still reported as missing.
    """

    name = "scipy-parameter"
    # NOTE(review): the id "W8111" is also declared by the numpy and pandas
    # parameter checkers added in the same commit -- confirm that message ids
    # are meant to be unique per symbol before registering all three together.
    msgs = {
        "W8111": (
            "Ensure that required parameters %s are explicitly specified in scipy method %s.",
            "scipy-parameter",
            "Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
        ),
    }

    # Define required parameters for specific Scipy classes and methods.
    # Keys are the dotted name as written at the call site (e.g. ``minimize``
    # after a direct import, or ``distance.euclidean``).
    REQUIRED_PARAMS = {
        # scipy.optimize
        "minimize": ["fun", "x0"],
        "curve_fit": ["f", "xdata", "ydata"],
        "root": ["fun", "x0"],
        # scipy.integrate
        "quad": ["func", "a", "b"],
        "dblquad": ["func", "a", "b", "gfun", "hfun"],
        "solve_ivp": ["fun", "t_span", "y0"],
        # scipy.stats
        "ttest_ind": ["a", "b"],
        "ttest_rel": ["a", "b"],
        "norm.pdf": ["x"],
        # scipy.spatial
        "distance.euclidean": ["u", "v"],  # Full chain
        "euclidean": ["u", "v"],  # Direct import of euclidean
        "KDTree.query": ["x"],
    }

    @only_required_for_messages("scipy-parameter")
    def visit_call(self, node: nodes.Call) -> None:
        """Emit ``scipy-parameter`` when a tracked call lacks listed keywords.

        Args:
            node: The call expression being visited.
        """
        method_name = self._get_full_method_name(node)
        required_key = self._resolve_required_key(method_name)
        if not required_key:
            return

        # ``**kwargs`` expansions have ``arg`` set to None and are ignored.
        provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
        # Collect all missing parameters so one message lists them together
        missing_params = [param for param in self.REQUIRED_PARAMS[required_key] if param not in provided_keywords]
        if missing_params:
            self.add_message(
                "scipy-parameter",
                node=node,
                confidence=HIGH,
                args=(", ".join(missing_params), method_name),
            )

    def _resolve_required_key(self, method_name: str) -> str:
        """Map a dotted call name onto a key of ``REQUIRED_PARAMS``.

        An exact match wins. Otherwise, when the name is rooted at ``scipy``
        (e.g. ``scipy.spatial.distance.euclidean``), leading components are
        stripped one at a time until the remainder is a known key, so that
        fully qualified calls are checked like their shorter spellings.
        Names with any other root are never shortened, to avoid matching
        unrelated objects that merely share a method name.

        Args:
            method_name: Dotted name as produced by ``_get_full_method_name``.

        Returns:
            The matching key, or an empty string when there is none.
        """
        if method_name in self.REQUIRED_PARAMS:
            return method_name

        parts = method_name.split(".")
        if parts[0] != "scipy":
            return ""

        # Longest remaining suffix first, so "distance.euclidean" is preferred
        # over the bare "euclidean".
        for start in range(1, len(parts)):
            candidate = ".".join(parts[start:])
            if candidate in self.REQUIRED_PARAMS:
                return candidate
        return ""

    def _get_full_method_name(self, node: nodes.Call) -> str:
        """Return the dotted name of the callee exactly as written in the call.

        Attribute chains are joined with dots and include the root name, so
        ``scipy.spatial.distance.euclidean(...)`` gives
        ``"scipy.spatial.distance.euclidean"`` and a directly imported
        ``euclidean(...)`` gives ``"euclidean"``.

        Args:
            node: The call expression to inspect.

        Returns:
            The dotted name. When the chain's root is not a plain name (for
            example the result of another call) only the attribute part is
            returned, and the result is empty if there are no attributes.
        """
        func = node.func
        method_chain = []

        # Traverse the attribute chain; the outermost attribute is seen first,
        # so parts are collected innermost-last and reversed afterwards.
        while isinstance(func, nodes.Attribute):
            method_chain.append(func.attrname)
            func = func.expr

        # The root of the chain (or a bare function name such as `euclidean`)
        if isinstance(func, nodes.Name):
            method_chain.append(func.name)

        return ".".join(reversed(method_chain))

0 commit comments

Comments
 (0)