Skip to content

Commit 1da3a27

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 959b8f1 commit 1da3a27

File tree

4 files changed

+91
-96
lines changed

4 files changed

+91
-96
lines changed

pylint_ml/checkers/tensorflow/tensor_parameter.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,21 @@ class TensorFlowParameterChecker(BaseChecker):
1616
"W8111": (
1717
"Ensure that required parameters %s are explicitly specified in TensorFlow method %s.",
1818
"tensor-parameter",
19-
"Explicitly specifying required parameters improves model performance and prevents unintended "
20-
"behavior.",
19+
"Explicitly specifying required parameters improves model performance and prevents unintended " "behavior.",
2120
),
2221
}
2322

2423
# Define required parameters for specific TensorFlow methods
2524
REQUIRED_PARAMS = {
2625
# Model Creation
27-
'Sequential': ['layers'], # Layers must be specified to build a model
28-
26+
"Sequential": ["layers"], # Layers must be specified to build a model
2927
# Model Compilation
30-
'compile': ['optimizer', 'loss'], # Optimizer and loss function are essential for training
31-
28+
"compile": ["optimizer", "loss"], # Optimizer and loss function are essential for training
3229
# Model Training
33-
'fit': ['x', 'y'], # Input data (x) and target data (y) are required to train the model
34-
30+
"fit": ["x", "y"], # Input data (x) and target data (y) are required to train the model
3531
# Layers
36-
'Conv2D': ['filters', 'kernel_size'], # Filters and kernel size define the convolutional layer's structure
37-
'Dense': ['units'], # Number of units (neurons) is crucial for a Dense layer
32+
"Conv2D": ["filters", "kernel_size"], # Filters and kernel size define the convolutional layer's structure
33+
"Dense": ["units"], # Number of units (neurons) is crucial for a Dense layer
3834
}
3935

4036
@only_required_for_messages("tensor-parameter")
@@ -44,15 +40,16 @@ def visit_call(self, node: nodes.Call) -> None:
4440
if method_name in self.REQUIRED_PARAMS:
4541
required_params = self.REQUIRED_PARAMS[method_name]
4642
# Check for explicit parameters
47-
missing_params = [param for param in required_params if
48-
not any(kw.arg == param for kw in node.keywords)]
43+
missing_params = [
44+
param for param in required_params if not any(kw.arg == param for kw in node.keywords)
45+
]
4946

5047
if missing_params:
5148
self.add_message(
5249
"tensor-parameter",
5350
node=node,
5451
confidence=HIGH,
55-
args=(', '.join(missing_params), method_name),
52+
args=(", ".join(missing_params), method_name),
5653
)
5754

5855
@only_required_for_messages("tensor-parameter")
@@ -72,5 +69,5 @@ def visit_call(self, node: nodes.Call) -> None:
7269
"tensor-parameter",
7370
node=node,
7471
confidence=HIGH,
75-
args=(', '.join(missing_params), method_name),
72+
args=(", ".join(missing_params), method_name),
7673
)

pylint_ml/checkers/torch/torch_parameter.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,20 @@ class PyTorchParameterChecker(BaseChecker):
1616
"W8111": (
1717
"Ensure that required parameters %s are explicitly specified in PyTorch method %s.",
1818
"pytorch-parameter",
19-
"Explicitly specifying required parameters improves model performance and prevents unintended "
20-
"behavior.",
19+
"Explicitly specifying required parameters improves model performance and prevents unintended " "behavior.",
2120
),
2221
}
2322

2423
# Define required parameters for specific PyTorch methods
2524
REQUIRED_PARAMS = {
2625
# Optimizers
27-
'SGD': ['lr'], # Focus on the critical learning rate parameter
28-
'Adam': ['lr'], # Learning rate is typically the most important for tuning
29-
26+
"SGD": ["lr"], # Focus on the critical learning rate parameter
27+
"Adam": ["lr"], # Learning rate is typically the most important for tuning
3028
# Layers
31-
'Conv2d': ['in_channels', 'out_channels', 'kernel_size'],
29+
"Conv2d": ["in_channels", "out_channels", "kernel_size"],
3230
# These parameters define the convolution's core operation
33-
'Linear': ['in_features', 'out_features'], # Essential to define the transformation dimensions
34-
'LSTM': ['input_size', 'hidden_size'], # Essential for defining the dimensionality of the LSTM cell
31+
"Linear": ["in_features", "out_features"], # Essential to define the transformation dimensions
32+
"LSTM": ["input_size", "hidden_size"], # Essential for defining the dimensionality of the LSTM cell
3533
}
3634

3735
@only_required_for_messages("pytorch-parameter")
@@ -41,13 +39,14 @@ def visit_call(self, node: nodes.Call) -> None:
4139
if method_name in self.REQUIRED_PARAMS:
4240
required_params = self.REQUIRED_PARAMS[method_name]
4341
# Check for explicit parameters
44-
missing_params = [param for param in required_params if
45-
not any(kw.arg == param for kw in node.keywords)]
42+
missing_params = [
43+
param for param in required_params if not any(kw.arg == param for kw in node.keywords)
44+
]
4645

4746
if missing_params:
4847
self.add_message(
4948
"pytorch-parameter",
5049
node=node,
5150
confidence=HIGH,
52-
args=(', '.join(missing_params), method_name),
51+
args=(", ".join(missing_params), method_name),
5352
)

tests/checkers/test_tensorflow/test_tensor_parameter.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ def test_sequential_params(self):
1919
sequential_call = node.value
2020

2121
with self.assertAddsMessages(
22-
pylint.testutils.MessageTest(
23-
msg_id="tensor-parameter",
24-
confidence=HIGH,
25-
node=sequential_call,
26-
args=("layers", "Sequential"),
27-
),
28-
ignore_position=True,
22+
pylint.testutils.MessageTest(
23+
msg_id="tensor-parameter",
24+
confidence=HIGH,
25+
node=sequential_call,
26+
args=("layers", "Sequential"),
27+
),
28+
ignore_position=True,
2929
):
3030
self.checker.visit_call(sequential_call)
3131

@@ -55,13 +55,13 @@ def test_compile_params(self):
5555
)
5656

5757
with self.assertAddsMessages(
58-
pylint.testutils.MessageTest(
59-
msg_id="tensor-parameter",
60-
confidence=HIGH,
61-
node=node,
62-
args=("optimizer, loss", "compile"),
63-
),
64-
ignore_position=True,
58+
pylint.testutils.MessageTest(
59+
msg_id="tensor-parameter",
60+
confidence=HIGH,
61+
node=node,
62+
args=("optimizer, loss", "compile"),
63+
),
64+
ignore_position=True,
6565
):
6666
self.checker.visit_call(node)
6767

@@ -92,13 +92,13 @@ def test_fit_params(self):
9292
fit_call = node
9393

9494
with self.assertAddsMessages(
95-
pylint.testutils.MessageTest(
96-
msg_id="tensor-parameter",
97-
confidence=HIGH,
98-
node=fit_call,
99-
args=("x, y", "fit"),
100-
),
101-
ignore_position=True,
95+
pylint.testutils.MessageTest(
96+
msg_id="tensor-parameter",
97+
confidence=HIGH,
98+
node=fit_call,
99+
args=("x, y", "fit"),
100+
),
101+
ignore_position=True,
102102
):
103103
self.checker.visit_call(fit_call)
104104

@@ -128,13 +128,13 @@ def test_conv2d_params(self):
128128
conv2d_call = node.value
129129

130130
with self.assertAddsMessages(
131-
pylint.testutils.MessageTest(
132-
msg_id="tensor-parameter",
133-
confidence=HIGH,
134-
node=conv2d_call,
135-
args=("filters", "Conv2D"),
136-
),
137-
ignore_position=True,
131+
pylint.testutils.MessageTest(
132+
msg_id="tensor-parameter",
133+
confidence=HIGH,
134+
node=conv2d_call,
135+
args=("filters", "Conv2D"),
136+
),
137+
ignore_position=True,
138138
):
139139
self.checker.visit_call(conv2d_call)
140140

@@ -162,13 +162,13 @@ def test_dense_params(self):
162162
dense_call = node.value
163163

164164
with self.assertAddsMessages(
165-
pylint.testutils.MessageTest(
166-
msg_id="tensor-parameter",
167-
confidence=HIGH,
168-
node=dense_call,
169-
args=("units", "Dense"),
170-
),
171-
ignore_position=True,
165+
pylint.testutils.MessageTest(
166+
msg_id="tensor-parameter",
167+
confidence=HIGH,
168+
node=dense_call,
169+
args=("units", "Dense"),
170+
),
171+
ignore_position=True,
172172
):
173173
self.checker.visit_call(dense_call)
174174

@@ -184,4 +184,3 @@ def test_dense_with_all_params(self):
184184

185185
with self.assertNoMessages():
186186
self.checker.visit_call(dense_call)
187-

tests/checkers/test_torch/test_torch_parameter.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ def test_sgd_params(self):
1919
sgd_call = node.value
2020

2121
with self.assertAddsMessages(
22-
pylint.testutils.MessageTest(
23-
msg_id="pytorch-parameter",
24-
confidence=HIGH,
25-
node=sgd_call,
26-
args=("lr", "SGD"), # Specify the expected missing parameters and method name
27-
),
28-
ignore_position=True,
22+
pylint.testutils.MessageTest(
23+
msg_id="pytorch-parameter",
24+
confidence=HIGH,
25+
node=sgd_call,
26+
args=("lr", "SGD"), # Specify the expected missing parameters and method name
27+
),
28+
ignore_position=True,
2929
):
3030
self.checker.visit_call(sgd_call)
3131

@@ -53,13 +53,13 @@ def test_adam_params(self):
5353
adam_call = node.value
5454

5555
with self.assertAddsMessages(
56-
pylint.testutils.MessageTest(
57-
msg_id="pytorch-parameter",
58-
confidence=HIGH,
59-
node=adam_call,
60-
args=("lr", "Adam"),
61-
),
62-
ignore_position=True,
56+
pylint.testutils.MessageTest(
57+
msg_id="pytorch-parameter",
58+
confidence=HIGH,
59+
node=adam_call,
60+
args=("lr", "Adam"),
61+
),
62+
ignore_position=True,
6363
):
6464
self.checker.visit_call(adam_call)
6565

@@ -87,13 +87,13 @@ def test_conv2d_params(self):
8787
conv2d_call = node.value
8888

8989
with self.assertAddsMessages(
90-
pylint.testutils.MessageTest(
91-
msg_id="pytorch-parameter",
92-
confidence=HIGH,
93-
node=conv2d_call,
94-
args=("out_channels", "Conv2d"),
95-
),
96-
ignore_position=True,
90+
pylint.testutils.MessageTest(
91+
msg_id="pytorch-parameter",
92+
confidence=HIGH,
93+
node=conv2d_call,
94+
args=("out_channels", "Conv2d"),
95+
),
96+
ignore_position=True,
9797
):
9898
self.checker.visit_call(conv2d_call)
9999

@@ -121,13 +121,13 @@ def test_linear_params(self):
121121
linear_call = node.value
122122

123123
with self.assertAddsMessages(
124-
pylint.testutils.MessageTest(
125-
msg_id="pytorch-parameter",
126-
confidence=HIGH,
127-
node=linear_call,
128-
args=("out_features", "Linear"),
129-
),
130-
ignore_position=True,
124+
pylint.testutils.MessageTest(
125+
msg_id="pytorch-parameter",
126+
confidence=HIGH,
127+
node=linear_call,
128+
args=("out_features", "Linear"),
129+
),
130+
ignore_position=True,
131131
):
132132
self.checker.visit_call(linear_call)
133133

@@ -155,13 +155,13 @@ def test_lstm_params(self):
155155
lstm_call = node.value
156156

157157
with self.assertAddsMessages(
158-
pylint.testutils.MessageTest(
159-
msg_id="pytorch-parameter",
160-
confidence=HIGH,
161-
node=lstm_call,
162-
args=("hidden_size", "LSTM"),
163-
),
164-
ignore_position=True,
158+
pylint.testutils.MessageTest(
159+
msg_id="pytorch-parameter",
160+
confidence=HIGH,
161+
node=lstm_call,
162+
args=("hidden_size", "LSTM"),
163+
),
164+
ignore_position=True,
165165
):
166166
self.checker.visit_call(lstm_call)
167167

0 commit comments

Comments
 (0)