Skip to content

Commit 5a130af

Browse files
authored
Merge pull request #52 from pylint-dev/29-add-sklearn-kmeans-hyperparameter-checker
Add parameter checker for torch, tensor and sklearn
2 parents b8e9583 + 1eef386 commit 5a130af

File tree

6 files changed

+227
-54
lines changed

6 files changed

+227
-54
lines changed

CONTRIBUTORS.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Maintainers
2+
-----------
3+
Peter Hamfelt
pylint_ml/checkers/sklearn/sklearn_parameter.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of Scikit-learn functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class SklearnParameterChecker(BaseChecker):
    """Flag scikit-learn calls that do not pass key parameters by keyword.

    Every ``Call`` node whose callee name is a key of ``REQUIRED_PARAMS`` is
    inspected. Each parameter listed for that name which is not supplied as a
    keyword argument is reported through the ``sklearn-parameter`` message.
    Matching is done on the bare callee name only, so any call named e.g.
    ``fit`` is checked regardless of the object it is called on.
    """

    name = "sklearn-parameter"
    # NOTE(review): the TensorFlow and PyTorch parameter checkers also declare
    # "W8111" with different symbols. Pylint is expected to reject one message
    # id bound to several symbols when all checkers are registered together --
    # confirm and allocate a distinct id if so.
    msgs = {
        "W8111": (
            "Ensure that required parameters %s are explicitly specified in Sklearn method %s.",
            "sklearn-parameter",
            "Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
        ),
    }

    # Define required parameters for specific scikit-learn classes and methods
    REQUIRED_PARAMS = {
        # Model Creation and Initialization
        "RandomForestClassifier": ["n_estimators"],  # Number of trees in the forest is crucial
        "SVC": ["C", "kernel"],  # Regularization parameter and kernel type are essential
        "LogisticRegression": ["penalty", "C"],  # Regularization penalty and strength are critical
        "KMeans": ["n_clusters"],  # Number of clusters to form is a core parameter
        # Model Training
        "fit": ["X", "y"],  # Input data (X) and target labels (y) are required for training
        # Cross-Validation
        "cross_val_score": ["estimator", "X"],  # Estimator and input data are essential for cross-validation
        # Grid Search
        "GridSearchCV": ["estimator", "param_grid"],  # Estimator and parameter grid are crucial for grid search
    }

    @only_required_for_messages("sklearn-parameter")
    def visit_call(self, node: nodes.Call) -> None:
        """Report required parameters that are not passed as keyword arguments.

        Only keyword arguments count as "explicitly specified"; positional
        arguments are deliberately not matched against parameter names.

        Args:
            node: The call expression being visited.
        """
        method_name = self._get_method_name(node)
        if method_name not in self.REQUIRED_PARAMS:
            return

        keyword_names = [kw.arg for kw in node.keywords]
        # ``kw.arg`` is None for ``**mapping`` unpacking. The mapping may well
        # carry the required parameters, so a HIGH-confidence "missing"
        # report would be a false positive; stay silent instead.
        if None in keyword_names:
            return

        provided_keywords = set(keyword_names)
        # Collect all missing parameters, preserving their declared order
        missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords]
        if missing_params:
            self.add_message(
                "sklearn-parameter",
                node=node,
                confidence=HIGH,
                args=(", ".join(missing_params), method_name),
            )

    @staticmethod
    def _get_method_name(node: nodes.Call) -> str:
        """Return the name being called by ``node``.

        For an attribute call such as ``a.b.c()`` this is the final attribute
        (``"c"``); for a plain call such as ``f()`` it is ``"f"``. Any other
        callee shape (subscript, nested call, lambda, ...) yields ``""``.
        """
        func = node.func
        if isinstance(func, nodes.Attribute):
            return func.attrname
        if isinstance(func, nodes.Name):
            return func.name
        return ""

pylint_ml/checkers/tensorflow/tensor_parameter.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TensorFlowParameterChecker(BaseChecker):
1616
"W8111": (
1717
"Ensure that required parameters %s are explicitly specified in TensorFlow method %s.",
1818
"tensor-parameter",
19-
"Explicitly specifying required parameters improves model performance and prevents unintended " "behavior.",
19+
"Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
2020
),
2121
}
2222

@@ -35,39 +35,27 @@ class TensorFlowParameterChecker(BaseChecker):
3535

3636
@only_required_for_messages("tensor-parameter")
3737
def visit_call(self, node: nodes.Call) -> None:
38-
if isinstance(node.func, nodes.Attribute):
39-
method_name = node.func.attrname
40-
if method_name in self.REQUIRED_PARAMS:
41-
required_params = self.REQUIRED_PARAMS[method_name]
42-
# Check for explicit parameters
43-
missing_params = [
44-
param for param in required_params if not any(kw.arg == param for kw in node.keywords)
45-
]
46-
47-
if missing_params:
48-
self.add_message(
49-
"tensor-parameter",
50-
node=node,
51-
confidence=HIGH,
52-
args=(", ".join(missing_params), method_name),
53-
)
54-
55-
@only_required_for_messages("tensor-parameter")
56-
def visit_call(self, node: nodes.Call) -> None:
57-
if isinstance(node.func, nodes.Attribute):
58-
method_name = node.func.attrname
59-
if method_name in self.REQUIRED_PARAMS:
60-
required_params = self.REQUIRED_PARAMS[method_name]
61-
# Extract all provided keyword arguments
62-
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
63-
64-
# Check if required parameters are provided explicitly as keyword arguments
65-
missing_params = [param for param in required_params if param not in provided_keywords]
66-
67-
if missing_params:
68-
self.add_message(
69-
"tensor-parameter",
70-
node=node,
71-
confidence=HIGH,
72-
args=(", ".join(missing_params), method_name),
73-
)
38+
method_name = self._get_method_name(node)
39+
if method_name in self.REQUIRED_PARAMS:
40+
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
41+
# Collect all missing parameters
42+
missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords]
43+
if missing_params:
44+
self.add_message(
45+
"tensor-parameter",
46+
node=node,
47+
confidence=HIGH,
48+
args=(", ".join(missing_params), method_name),
49+
)
50+
51+
@staticmethod
def _get_method_name(node: nodes.Call) -> str:
    """Return the name being called by ``node``.

    An attribute call such as ``a.b.c()`` yields its final attribute
    (``"c"``), a plain call such as ``f()`` yields ``"f"``, and any other
    callee shape yields an empty string.
    """
    callee = node.func
    if isinstance(callee, nodes.Attribute):
        return callee.attrname
    if isinstance(callee, nodes.Name):
        return callee.name
    return ""

pylint_ml/checkers/torch/torch_parameter.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class PyTorchParameterChecker(BaseChecker):
1616
"W8111": (
1717
"Ensure that required parameters %s are explicitly specified in PyTorch method %s.",
1818
"pytorch-parameter",
19-
"Explicitly specifying required parameters improves model performance and prevents unintended " "behavior.",
19+
"Explicitly specifying required parameters improves model performance and prevents unintended behavior.",
2020
),
2121
}
2222

@@ -34,19 +34,27 @@ class PyTorchParameterChecker(BaseChecker):
3434

3535
@only_required_for_messages("pytorch-parameter")
3636
def visit_call(self, node: nodes.Call) -> None:
37-
if isinstance(node.func, nodes.Attribute):
38-
method_name = node.func.attrname
39-
if method_name in self.REQUIRED_PARAMS:
40-
required_params = self.REQUIRED_PARAMS[method_name]
41-
# Check for explicit parameters
42-
missing_params = [
43-
param for param in required_params if not any(kw.arg == param for kw in node.keywords)
44-
]
37+
method_name = self._get_method_name(node)
38+
if method_name in self.REQUIRED_PARAMS:
39+
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
40+
# Collect all missing parameters
41+
missing_params = [param for param in self.REQUIRED_PARAMS[method_name] if param not in provided_keywords]
42+
if missing_params:
43+
self.add_message(
44+
"pytorch-parameter",
45+
node=node,
46+
confidence=HIGH,
47+
args=(", ".join(missing_params), method_name),
48+
)
4549

46-
if missing_params:
47-
self.add_message(
48-
"pytorch-parameter",
49-
node=node,
50-
confidence=HIGH,
51-
args=(", ".join(missing_params), method_name),
52-
)
50+
@staticmethod
def _get_method_name(node: nodes.Call) -> str:
    """Return the name being called by ``node``.

    An attribute call such as ``a.b.c()`` yields its final attribute
    (``"c"``), a plain call such as ``f()`` yields ``"f"``, and any other
    callee shape yields an empty string.
    """
    callee = node.func
    if isinstance(callee, nodes.Attribute):
        return callee.attrname
    if isinstance(callee, nodes.Name):
        return callee.name
    return ""
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import astroid
2+
import pylint.testutils
3+
from pylint.interfaces import HIGH
4+
5+
from pylint_ml.checkers.sklearn.sklearn_parameter import SklearnParameterChecker
6+
7+
8+
class TestSklearnParameterChecker(pylint.testutils.CheckerTestCase):
    """Unit tests for ``SklearnParameterChecker``.

    Each estimator is exercised twice: once without the keyword arguments the
    checker looks for (a message is expected) and once with them (no message).
    """

    CHECKER_CLASS = SklearnParameterChecker

    @staticmethod
    def _call_from(source):
        """Parse ``source`` and return the call on the right-hand side of its last assignment."""
        return astroid.extract_node(source).value

    def _expect_missing(self, call, missing, callee):
        """Assert that visiting ``call`` reports ``missing`` for ``callee`` and nothing else."""
        expected = pylint.testutils.MessageTest(
            msg_id="sklearn-parameter",
            confidence=HIGH,
            node=call,
            args=(missing, callee),
        )
        with self.assertAddsMessages(expected, ignore_position=True):
            self.checker.visit_call(call)

    def _expect_clean(self, call):
        """Assert that visiting ``call`` reports nothing."""
        with self.assertNoMessages():
            self.checker.visit_call(call)

    def test_random_forest_params(self):
        forest_call = self._call_from(
            """
            from sklearn.ensemble import RandomForestClassifier
            clf = RandomForestClassifier() # [sklearn-parameter]
            """
        )
        self._expect_missing(forest_call, "n_estimators", "RandomForestClassifier")

    def test_random_forest_with_params(self):
        forest_call = self._call_from(
            """
            from sklearn.ensemble import RandomForestClassifier
            clf = RandomForestClassifier(n_estimators=100) # Should not trigger
            """
        )
        self._expect_clean(forest_call)

    def test_svc_params(self):
        svc_call = self._call_from(
            """
            from sklearn.svm import SVC
            clf = SVC() # [sklearn-parameter]
            """
        )
        self._expect_missing(svc_call, "C, kernel", "SVC")

    def test_svc_with_params(self):
        svc_call = self._call_from(
            """
            from sklearn.svm import SVC
            clf = SVC(C=1.0, kernel='linear') # Should not trigger
            """
        )
        self._expect_clean(svc_call)

    def test_kmeans_params(self):
        kmeans_call = self._call_from(
            """
            from sklearn.cluster import KMeans
            kmeans = KMeans() # [sklearn-parameter]
            """
        )
        self._expect_missing(kmeans_call, "n_clusters", "KMeans")

    def test_kmeans_with_params(self):
        kmeans_call = self._call_from(
            """
            from sklearn.cluster import KMeans
            kmeans = KMeans(n_clusters=8) # Should not trigger
            """
        )
        self._expect_clean(kmeans_call)

tests/checkers/test_tensorflow/test_tensor_parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_compile_with_all_params(self):
7070
"""
7171
import tensorflow as tf
7272
model = tf.keras.models.Sequential()
73-
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # Should not trigger
73+
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # No trigger
7474
"""
7575
)
7676

0 commit comments

Comments
 (0)