Skip to content

Commit 959b8f1

Browse files
author
Peter Hamfelt
committed
Add parameter checkers
1 parent 6c7e5a5 commit 959b8f1

File tree

11 files changed

+497
-98
lines changed

11 files changed

+497
-98
lines changed

pylint_ml/checkers/sklearn/sklearn_parameter.py

Whitespace-only changes.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of Tensorflow functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class TensorFlowParameterChecker(BaseChecker):
14+
name = "tensor-parameter"
15+
msgs = {
16+
"W8111": (
17+
"Ensure that required parameters %s are explicitly specified in TensorFlow method %s.",
18+
"tensor-parameter",
19+
"Explicitly specifying required parameters improves model performance and prevents unintended "
20+
"behavior.",
21+
),
22+
}
23+
24+
# Define required parameters for specific TensorFlow methods
25+
REQUIRED_PARAMS = {
26+
# Model Creation
27+
'Sequential': ['layers'], # Layers must be specified to build a model
28+
29+
# Model Compilation
30+
'compile': ['optimizer', 'loss'], # Optimizer and loss function are essential for training
31+
32+
# Model Training
33+
'fit': ['x', 'y'], # Input data (x) and target data (y) are required to train the model
34+
35+
# Layers
36+
'Conv2D': ['filters', 'kernel_size'], # Filters and kernel size define the convolutional layer's structure
37+
'Dense': ['units'], # Number of units (neurons) is crucial for a Dense layer
38+
}
39+
40+
@only_required_for_messages("tensor-parameter")
41+
def visit_call(self, node: nodes.Call) -> None:
42+
if isinstance(node.func, nodes.Attribute):
43+
method_name = node.func.attrname
44+
if method_name in self.REQUIRED_PARAMS:
45+
required_params = self.REQUIRED_PARAMS[method_name]
46+
# Check for explicit parameters
47+
missing_params = [param for param in required_params if
48+
not any(kw.arg == param for kw in node.keywords)]
49+
50+
if missing_params:
51+
self.add_message(
52+
"tensor-parameter",
53+
node=node,
54+
confidence=HIGH,
55+
args=(', '.join(missing_params), method_name),
56+
)
57+
58+
@only_required_for_messages("tensor-parameter")
59+
def visit_call(self, node: nodes.Call) -> None:
60+
if isinstance(node.func, nodes.Attribute):
61+
method_name = node.func.attrname
62+
if method_name in self.REQUIRED_PARAMS:
63+
required_params = self.REQUIRED_PARAMS[method_name]
64+
# Extract all provided keyword arguments
65+
provided_keywords = {kw.arg for kw in node.keywords if kw.arg is not None}
66+
67+
# Check if required parameters are provided explicitly as keyword arguments
68+
missing_params = [param for param in required_params if param not in provided_keywords]
69+
70+
if missing_params:
71+
self.add_message(
72+
"tensor-parameter",
73+
node=node,
74+
confidence=HIGH,
75+
args=(', '.join(missing_params), method_name),
76+
)

pylint_ml/checkers/torch/torch_import.py

Lines changed: 0 additions & 43 deletions
This file was deleted.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Licensed under the MIT: https://mit-license.org/
2+
# For details: https://github.com/pylint-dev/pylint-ml/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/pylint-ml/CONTRIBUTORS.txt
4+
5+
"""Check for proper usage of PyTorch functions with required parameters."""
6+
7+
from astroid import nodes
8+
from pylint.checkers import BaseChecker
9+
from pylint.checkers.utils import only_required_for_messages
10+
from pylint.interfaces import HIGH
11+
12+
13+
class PyTorchParameterChecker(BaseChecker):
14+
name = "pytorch-parameter"
15+
msgs = {
16+
"W8111": (
17+
"Ensure that required parameters %s are explicitly specified in PyTorch method %s.",
18+
"pytorch-parameter",
19+
"Explicitly specifying required parameters improves model performance and prevents unintended "
20+
"behavior.",
21+
),
22+
}
23+
24+
# Define required parameters for specific PyTorch methods
25+
REQUIRED_PARAMS = {
26+
# Optimizers
27+
'SGD': ['lr'], # Focus on the critical learning rate parameter
28+
'Adam': ['lr'], # Learning rate is typically the most important for tuning
29+
30+
# Layers
31+
'Conv2d': ['in_channels', 'out_channels', 'kernel_size'],
32+
# These parameters define the convolution's core operation
33+
'Linear': ['in_features', 'out_features'], # Essential to define the transformation dimensions
34+
'LSTM': ['input_size', 'hidden_size'], # Essential for defining the dimensionality of the LSTM cell
35+
}
36+
37+
@only_required_for_messages("pytorch-parameter")
38+
def visit_call(self, node: nodes.Call) -> None:
39+
if isinstance(node.func, nodes.Attribute):
40+
method_name = node.func.attrname
41+
if method_name in self.REQUIRED_PARAMS:
42+
required_params = self.REQUIRED_PARAMS[method_name]
43+
# Check for explicit parameters
44+
missing_params = [param for param in required_params if
45+
not any(kw.arg == param for kw in node.keywords)]
46+
47+
if missing_params:
48+
self.add_message(
49+
"pytorch-parameter",
50+
node=node,
51+
confidence=HIGH,
52+
args=(', '.join(missing_params), method_name),
53+
)

pylint_ml/plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pylint_ml.checkers.scipy.scipy_import import ScipyImportChecker
99
from pylint_ml.checkers.sklearn.sklearn_import import SklearnImportChecker
1010
from pylint_ml.checkers.tensorflow.tensorflow_import import TensorflowImportChecker
11-
from pylint_ml.checkers.torch.torch_import import TorchImportChecker
11+
from pylint_ml.checkers.torch.torch_parameter import PyTorchParameterChecker
1212

1313

1414
def register(linter: PyLinter) -> None:
@@ -24,7 +24,7 @@ def register(linter: PyLinter) -> None:
2424
linter.register_checker(TensorflowImportChecker(linter))
2525

2626
# Torch
27-
linter.register_checker(TorchImportChecker(linter))
27+
linter.register_checker(PyTorchParameterChecker(linter))
2828

2929
# Scipy
3030
linter.register_checker(ScipyImportChecker(linter))
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import astroid
2+
import pylint.testutils
3+
from pylint.interfaces import HIGH
4+
5+
from pylint_ml.checkers.tensorflow.tensor_parameter import TensorFlowParameterChecker
6+
7+
8+
class TestTensorParameterChecker(pylint.testutils.CheckerTestCase):
9+
CHECKER_CLASS = TensorFlowParameterChecker
10+
11+
def test_sequential_params(self):
12+
node = astroid.extract_node(
13+
"""
14+
import tensorflow as tf
15+
model = tf.keras.models.Sequential() # [tensor-parameter]
16+
"""
17+
)
18+
19+
sequential_call = node.value
20+
21+
with self.assertAddsMessages(
22+
pylint.testutils.MessageTest(
23+
msg_id="tensor-parameter",
24+
confidence=HIGH,
25+
node=sequential_call,
26+
args=("layers", "Sequential"),
27+
),
28+
ignore_position=True,
29+
):
30+
self.checker.visit_call(sequential_call)
31+
32+
def test_sequential_with_layers(self):
33+
node = astroid.extract_node(
34+
"""
35+
import tensorflow as tf
36+
model = tf.keras.Sequential(layers=[
37+
tf.keras.layers.Dense(units=64, activation='relu'),
38+
tf.keras.layers.Dense(units=10)
39+
])
40+
"""
41+
)
42+
43+
sequential_call = node.value
44+
45+
with self.assertNoMessages():
46+
self.checker.visit_call(sequential_call)
47+
48+
def test_compile_params(self):
49+
node = astroid.extract_node(
50+
"""
51+
import tensorflow as tf
52+
model = tf.keras.models.Sequential()
53+
model.compile() # [tensor-parameter]
54+
"""
55+
)
56+
57+
with self.assertAddsMessages(
58+
pylint.testutils.MessageTest(
59+
msg_id="tensor-parameter",
60+
confidence=HIGH,
61+
node=node,
62+
args=("optimizer, loss", "compile"),
63+
),
64+
ignore_position=True,
65+
):
66+
self.checker.visit_call(node)
67+
68+
def test_compile_with_all_params(self):
69+
node = astroid.extract_node(
70+
"""
71+
import tensorflow as tf
72+
model = tf.keras.models.Sequential()
73+
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # Should not trigger
74+
"""
75+
)
76+
77+
compile_call = node
78+
79+
with self.assertNoMessages():
80+
self.checker.visit_call(compile_call)
81+
82+
def test_fit_params(self):
83+
node = astroid.extract_node(
84+
"""
85+
import tensorflow as tf
86+
model = tf.keras.models.Sequential()
87+
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
88+
model.fit(epochs=10) # [tensor-parameter]
89+
"""
90+
)
91+
92+
fit_call = node
93+
94+
with self.assertAddsMessages(
95+
pylint.testutils.MessageTest(
96+
msg_id="tensor-parameter",
97+
confidence=HIGH,
98+
node=fit_call,
99+
args=("x, y", "fit"),
100+
),
101+
ignore_position=True,
102+
):
103+
self.checker.visit_call(fit_call)
104+
105+
def test_fit_with_all_params(self):
106+
node = astroid.extract_node(
107+
"""
108+
import tensorflow as tf
109+
model = tf.keras.models.Sequential()
110+
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
111+
model.fit(x=train_data, y=train_labels, epochs=10) # Should not trigger
112+
"""
113+
)
114+
115+
fit_call = node
116+
117+
with self.assertNoMessages():
118+
self.checker.visit_call(fit_call)
119+
120+
def test_conv2d_params(self):
121+
node = astroid.extract_node(
122+
"""
123+
import tensorflow as tf
124+
layer = tf.keras.layers.Conv2D(kernel_size=(3, 3)) # [tensor-parameter]
125+
"""
126+
)
127+
128+
conv2d_call = node.value
129+
130+
with self.assertAddsMessages(
131+
pylint.testutils.MessageTest(
132+
msg_id="tensor-parameter",
133+
confidence=HIGH,
134+
node=conv2d_call,
135+
args=("filters", "Conv2D"),
136+
),
137+
ignore_position=True,
138+
):
139+
self.checker.visit_call(conv2d_call)
140+
141+
def test_conv2d_with_all_params(self):
142+
node = astroid.extract_node(
143+
"""
144+
import tensorflow as tf
145+
layer = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3)) # Should not trigger
146+
"""
147+
)
148+
149+
conv2d_call = node.value
150+
151+
with self.assertNoMessages():
152+
self.checker.visit_call(conv2d_call)
153+
154+
def test_dense_params(self):
155+
node = astroid.extract_node(
156+
"""
157+
import tensorflow as tf
158+
layer = tf.keras.layers.Dense() # [tensor-parameter]
159+
"""
160+
)
161+
162+
dense_call = node.value
163+
164+
with self.assertAddsMessages(
165+
pylint.testutils.MessageTest(
166+
msg_id="tensor-parameter",
167+
confidence=HIGH,
168+
node=dense_call,
169+
args=("units", "Dense"),
170+
),
171+
ignore_position=True,
172+
):
173+
self.checker.visit_call(dense_call)
174+
175+
def test_dense_with_all_params(self):
176+
node = astroid.extract_node(
177+
"""
178+
import tensorflow as tf
179+
layer = tf.keras.layers.Dense(units=64) # Should not trigger
180+
"""
181+
)
182+
183+
dense_call = node.value
184+
185+
with self.assertNoMessages():
186+
self.checker.visit_call(dense_call)
187+

0 commit comments

Comments
 (0)