Skip to content

Commit 4eaf6d8

Browse files
author
Peter Hamfelt
committed
add lib parameter checkers
1 parent b8d5cdd commit 4eaf6d8

File tree

11 files changed

+751
-194
lines changed

11 files changed

+751
-194
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of numpy functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class NumPyParameterChecker(BaseChecker):
    """Report ``np.<function>(...)`` calls that omit parameters listed in ``REQUIRED_PARAMS``.

    Only calls whose attribute chain is rooted at the name ``np`` are examined.
    For most functions a parameter counts as supplied only when it is passed by
    keyword.  Functions listed in ``POSITIONAL_ONLY_FUNCTIONS`` cannot receive
    their listed parameters by keyword at all, so for those the positional
    arguments are counted instead.
    """

    name = "numpy-parameter"
    # NOTE(review): the pandas and scipy parameter checkers added alongside this
    # one also declare msgid W8111 under a different symbol.  pylint is expected
    # to reject one msgid bound to several symbols - confirm and assign unique ids.
    msgs = {
        "W8111": (
            "Ensure that required parameters %s are explicitly specified in numpy method %s.",
            "numpy-parameter",
            "Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
        ),
    }

    # Maps the dotted function name (without the leading ``np.``) to the
    # parameters that must be present in the call.
    REQUIRED_PARAMS = {
        # Array Creation
        "array": ["object"],
        "zeros": ["shape"],
        "ones": ["shape"],
        "full": ["shape", "fill_value"],
        "empty": ["shape"],
        "arange": ["start"],
        "linspace": ["start", "stop"],
        "logspace": ["start", "stop"],
        "eye": ["N"],
        "identity": ["n"],
        # Random Sampling
        "random.rand": ["d0"],  # positional only, see POSITIONAL_ONLY_FUNCTIONS
        "random.randn": ["d0"],  # positional only, see POSITIONAL_ONLY_FUNCTIONS
        "random.randint": ["low", "high"],
        "random.choice": ["a"],
        "random.uniform": ["low", "high"],
        "random.normal": ["loc", "scale"],
        # Mathematical Functions
        "sum": ["a"],
        "mean": ["a"],
        "median": ["a"],
        "std": ["a"],
        "var": ["a"],
        "prod": ["a"],
        "min": ["a"],
        "max": ["a"],
        "ptp": ["a"],
        # Array Manipulation
        "reshape": ["newshape"],
        "transpose": [],
        "concatenate": ["arrays"],  # positional only, see POSITIONAL_ONLY_FUNCTIONS
        "stack": ["arrays"],
        # numpy names the sequence parameter of vstack/hstack ``tup``.
        "vstack": ["tup"],
        "hstack": ["tup"],
        # Linear Algebra
        "dot": ["a", "b"],
        # matmul is a ufunc; its inputs are documented as x1 and x2.
        "matmul": ["x1", "x2"],  # positional only, see POSITIONAL_ONLY_FUNCTIONS
        "linalg.inv": ["a"],
        "linalg.eig": ["a"],
        "linalg.solve": ["a", "b"],
        # Statistical Functions
        "percentile": ["a", "q"],
        "quantile": ["a", "q"],
        "corrcoef": ["x"],
        "cov": ["m"],
    }

    # Functions whose listed parameters cannot be passed by keyword (``*args``
    # signatures, ufunc inputs, or a positional-only first argument).  Asking for
    # a keyword here would push users towards a call that raises ``TypeError``,
    # so positional arguments satisfy the requirement in listed order.
    POSITIONAL_ONLY_FUNCTIONS = frozenset(
        {"random.rand", "random.randn", "matmul", "concatenate"}
    )

    @only_required_for_messages("numpy-parameter")
    def visit_call(self, node: nodes.Call) -> None:
        """Emit ``numpy-parameter`` when a tracked NumPy call lacks listed parameters.

        Args:
            node: The call expression being visited.
        """
        method_name = self._get_full_method_name(node)
        required = self.REQUIRED_PARAMS.get(method_name)
        if not required:
            # Not a tracked ``np.`` call, or nothing is required for it.
            return

        if method_name in self.POSITIONAL_ONLY_FUNCTIONS:
            if any(isinstance(arg, nodes.Starred) for arg in node.args):
                # ``*args`` unpacking hides how many positionals are passed.
                return
            missing_params = required[len(node.args):]
        else:
            if any(kw.arg is None for kw in node.keywords):
                # ``**kwargs`` unpacking may supply any of the parameters.
                return
            provided_keywords = {kw.arg for kw in node.keywords}
            missing_params = [param for param in required if param not in provided_keywords]

        if missing_params:
            self.add_message(
                "numpy-parameter",
                node=node,
                confidence=HIGH,
                args=(", ".join(missing_params), method_name),
            )

    @staticmethod
    def _get_full_method_name(node: nodes.Call) -> str:
        """Return the dotted name after ``np.`` for a call, or ``""``.

        ``np.random.rand(...)`` yields ``"random.rand"``.  Calls whose attribute
        chain is not rooted at the plain name ``np`` yield an empty string.

        Args:
            node: The call expression to inspect.

        Returns:
            The attribute chain joined with dots, without the ``np`` root.
        """
        func = node.func
        reversed_chain = []

        # Walk from the called attribute back towards the root expression.
        while isinstance(func, nodes.Attribute):
            reversed_chain.append(func.attrname)
            func = func.expr

        # Only chains rooted at the conventional ``np`` alias are considered.
        if isinstance(func, nodes.Name) and func.name == "np":
            return ".".join(reversed(reversed_chain))
        return ""
110+

pylint_ml/checkers/pandas/pandas_dataframe_merge.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

pylint_ml/checkers/pandas/pandas_dtype_param.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of Pandas functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class PandasParameterChecker(BaseChecker):
    """Report calls to Pandas-named functions that omit parameters in ``REQUIRED_PARAMS``.

    Calls are matched by the final name only (``df.merge`` and ``merge`` both map
    to ``"merge"``), and a parameter counts as supplied only when it is passed by
    keyword.  Because many of the tracked names also exist on builtins and in
    other libraries (``str.join``, ``str.replace``, ``os.path.join``,
    ``pyplot.plot``), calls whose receiver is known not to come from pandas are
    skipped.
    """

    name = "pandas-parameter"
    # NOTE(review): the numpy and scipy parameter checkers added alongside this
    # one also declare msgid W8111 under a different symbol.  pylint is expected
    # to reject one msgid bound to several symbols - confirm and assign unique ids.
    msgs = {
        "W8111": (
            "Ensure that required parameters %s are explicitly specified in Pandas method %s.",
            "pandas-parameter",
            "Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
        ),
    }

    # Define required parameters for specific Pandas classes and methods
    REQUIRED_PARAMS = {
        # DataFrame creation
        "DataFrame": ["data"],  # The primary input data for DataFrame creation
        # Concatenation
        "concat": ["objs"],  # The list or dictionary of DataFrames/Series to concatenate
        # DataFrame I/O (Input/Output)
        "read_csv": ["filepath_or_buffer", "dtype"],  # Path to CSV file or file-like object; column data types
        "read_excel": ["io", "dtype"],  # Path to Excel file or file-like object; column data types
        "read_table": ["filepath_or_buffer", "dtype"],  # Path to delimited text file or file object; column data types
        "to_csv": ["path_or_buf"],  # File path or buffer to write the DataFrame to
        "to_excel": ["excel_writer"],  # File path or ExcelWriter object to write the data to
        # Merging and Joining
        "merge": ["right", "how", "on", "validate"],  # The DataFrame or Series to merge with, and how to merge
        "join": ["other"],  # The DataFrame or Series to join
        # DataFrame Operations
        "pivot_table": ["index"],  # The column to pivot on (values and columns have defaults)
        "groupby": ["by"],  # The key or list of keys to group by
        "resample": ["rule"],  # The frequency rule to resample by
        # Data Cleaning and Transformation
        "fillna": ["value"],  # Value to use to fill NA/NaN values
        "drop": ["labels"],  # Labels to drop
        "drop_duplicates": ["subset"],  # Subset of columns to consider when dropping duplicates
        "replace": ["to_replace"],  # Values to replace
        # Plotting
        "plot": ["x"],  # x-values or index for plotting
        "hist": ["column"],  # Column to plot the histogram for
        "boxplot": ["column"],  # Column(s) to plot boxplot for
        # DataFrame Sorting
        "sort_values": ["by"],  # Column(s) to sort by
        "sort_index": ["axis"],  # Axis to sort along (index=0, columns=1)
        # Statistical Functions
        "corr": ["method"],  # Method to use for correlation ('pearson', 'kendall', 'spearman')
        "describe": [],  # No required parameters, but additional ones could be specified
        # Windowing/Resampling Functions
        "rolling": ["window"],  # Size of the moving window
        "ewm": ["span"],  # Span for exponentially weighted calculations
        # Miscellaneous Functions
        "apply": ["func"],  # Function to apply to the data
        "agg": ["func"],  # Function or list of functions for aggregation
    }

    @only_required_for_messages("pandas-parameter")
    def visit_call(self, node: nodes.Call) -> None:
        """Emit ``pandas-parameter`` when a tracked call lacks listed keyword arguments.

        Args:
            node: The call expression being visited.
        """
        method_name = self._get_method_name(node)
        required = self.REQUIRED_PARAMS.get(method_name)
        if not required:
            # Untracked name, or a tracked name with nothing required.
            return
        if self._receiver_is_not_pandas(node):
            # e.g. ``", ".join(parts)`` or ``os.path.join(a, b)``.
            return
        if any(kw.arg is None for kw in node.keywords):
            # ``**kwargs`` unpacking may supply any of the parameters.
            return

        provided_keywords = {kw.arg for kw in node.keywords}
        missing_params = [param for param in required if param not in provided_keywords]
        if missing_params:
            self.add_message(
                "pandas-parameter",
                node=node,
                confidence=HIGH,
                args=(", ".join(missing_params), method_name),
            )

    @staticmethod
    def _receiver_is_not_pandas(node: nodes.Call) -> bool:
        """Tell whether the object a method is called on is known to be non-pandas.

        The answer is ``True`` only when the receiver is an f-string, infers to a
        builtin literal (str/bytes/number/None, list, tuple, dict, set), or infers
        to a module whose dotted name has no ``pandas`` component.  When nothing
        definite is known the call is kept, so this never hides a call that the
        checker could previously attribute to pandas objects.

        Args:
            node: The call expression to inspect.

        Returns:
            ``True`` if the call should be ignored, ``False`` otherwise.
        """
        func = node.func
        if not isinstance(func, nodes.Attribute):
            # Bare names such as ``DataFrame(...)`` have no receiver to inspect.
            return False

        receiver = func.expr
        if isinstance(receiver, nodes.JoinedStr):
            return True

        # Imported locally so this class does not depend on an edit to the
        # module-level import block.
        from pylint.checkers.utils import safe_infer

        inferred = safe_infer(receiver)
        if isinstance(inferred, (nodes.Const, nodes.List, nodes.Tuple, nodes.Dict, nodes.Set)):
            return True
        if isinstance(inferred, nodes.Module):
            return "pandas" not in inferred.name.split(".")
        return False

    @staticmethod
    def _get_method_name(node: nodes.Call) -> str:
        """Return the final name of the called function.

        Args:
            node: The call expression to inspect.

        Returns:
            The attribute name for ``obj.method(...)`` (however long the chain),
            the plain name for ``function(...)``, and ``""`` for anything else.
        """
        func = node.func
        if isinstance(func, nodes.Attribute):
            return func.attrname
        if isinstance(func, nodes.Name):
            return func.name
        return ""

pylint_ml/checkers/scipy/scipy_import.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ class ScipyImportChecker(BaseChecker):
2525
),
2626
}
2727

28-
@only_required_for_messages("scipy-import")
28+
@only_required_for_messages("scipy-import", "scipy-wildcard-import")
2929
def visit_import(self, node: nodes.Import) -> None:
30-
for name, _ in node.names:
30+
for name, _alias in node.names:
3131
if name == "scipy":
32+
# Flag direct or aliased imports of scipy
3233
self.add_message("scipy-import", node=node, confidence=HIGH)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of Scipy functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class ScipyParameterChecker(BaseChecker):
    """Report calls to tracked SciPy functions that omit parameters in ``REQUIRED_PARAMS``.

    A parameter counts as supplied only when it is passed by keyword.  A call
    matches a ``REQUIRED_PARAMS`` key either by its complete dotted name
    (``minimize(...)``, ``distance.euclidean(...)``) or after dropping leading
    SciPy package names (``optimize.minimize(...)``,
    ``scipy.spatial.distance.euclidean(...)``).
    """

    name = "scipy-parameter"
    # NOTE(review): the numpy and pandas parameter checkers added alongside this
    # one also declare msgid W8111 under a different symbol.  pylint is expected
    # to reject one msgid bound to several symbols - confirm and assign unique ids.
    msgs = {
        "W8111": (
            "Ensure that required parameters %s are explicitly specified in scipy method %s.",
            "scipy-parameter",
            "Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
        ),
    }

    # Define required parameters for specific Scipy classes and methods
    REQUIRED_PARAMS = {
        # scipy.optimize
        "minimize": ["fun", "x0"],
        "curve_fit": ["f", "xdata", "ydata"],
        "root": ["fun", "x0"],
        # scipy.integrate
        "quad": ["func", "a", "b"],
        "dblquad": ["func", "a", "b", "gfun", "hfun"],
        "solve_ivp": ["fun", "t_span", "y0"],
        # scipy.stats
        "ttest_ind": ["a", "b"],
        "ttest_rel": ["a", "b"],
        "norm.pdf": ["x"],
        # scipy.spatial
        "distance.euclidean": ["u", "v"],  # Full chain
        "euclidean": ["u", "v"],  # Direct import of euclidean
        "KDTree.query": ["x"],
    }

    # Package names that may precede a REQUIRED_PARAMS key in a qualified call.
    # Restricting the strippable prefix to these keeps unrelated attribute calls
    # such as ``tree.root(...)`` from being treated as SciPy functions.
    MODULE_PREFIXES = frozenset({"scipy", "optimize", "integrate", "stats", "spatial"})

    @only_required_for_messages("scipy-parameter")
    def visit_call(self, node: nodes.Call) -> None:
        """Emit ``scipy-parameter`` when a tracked call lacks listed keyword arguments.

        Args:
            node: The call expression being visited.
        """
        method_name = self._find_required_key(self._get_full_method_name(node))
        if not method_name:
            return
        if any(kw.arg is None for kw in node.keywords):
            # ``**kwargs`` unpacking may supply any of the parameters.
            return

        provided_keywords = {kw.arg for kw in node.keywords}
        missing_params = [
            param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords
        ]
        if missing_params:
            self.add_message(
                "scipy-parameter",
                node=node,
                confidence=HIGH,
                args=(", ".join(missing_params), method_name),
            )

    def _find_required_key(self, full_name: str) -> str:
        """Map a dotted call name onto a ``REQUIRED_PARAMS`` key.

        The complete name is tried first, then progressively shorter suffixes,
        but a leading part is only dropped when it is in ``MODULE_PREFIXES``.

        Args:
            full_name: Dotted name as produced by ``_get_full_method_name``.

        Returns:
            The matching key, or ``""`` when the call is not tracked.
        """
        parts = full_name.split(".")
        for start in range(len(parts)):
            if start and parts[start - 1] not in self.MODULE_PREFIXES:
                # The part about to be dropped is not a SciPy package name.
                break
            candidate = ".".join(parts[start:])
            if candidate in self.REQUIRED_PARAMS:
                return candidate
        return ""

    def _get_full_method_name(self, node: nodes.Call) -> str:
        """Return the complete dotted name of the called function.

        ``scipy.spatial.distance.euclidean(...)`` yields
        ``"scipy.spatial.distance.euclidean"`` and a directly imported
        ``euclidean(...)`` yields ``"euclidean"``.  If the chain is not rooted at
        a plain name, only the attribute names are returned.

        Args:
            node: The call expression to inspect.

        Returns:
            The root name (when present) and attribute names joined with dots.
        """
        func = node.func
        reversed_chain = []

        # Walk from the called attribute back towards the root expression.
        while isinstance(func, nodes.Attribute):
            reversed_chain.append(func.attrname)
            func = func.expr

        # Include the root name so directly imported functions resolve as well.
        if isinstance(func, nodes.Name):
            reversed_chain.append(func.name)

        return ".".join(reversed(reversed_chain))

0 commit comments

Comments
 (0)