Skip to content

Commit 5192d90

Browse files
authored
SONARPY-1907 implement Rule S6983 : The num_workers parameter should be specified for torch.utils.data.DataLoader (#1955)
* SONARPY-1907 add metadata * SONARPY-1907 implement Rule S6983 : The num_workers parameter should be specified for torch.utils.data.DataLoader * SONARPY-1907 update expected raised issues for S6983 * SONARPY-1907 small change in accordance with PR
1 parent 325e0ec commit 5192d90

File tree

8 files changed

+223
-1
lines changed

8 files changed

+223
-1
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"project:pecos/examples/MACLR/evaluate.py": [
3+
86,
4+
112
5+
],
6+
"project:pecos/examples/MACLR/main.py": [
7+
227,
8+
231,
9+
307
10+
],
11+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-papers100M/mlp_sgc.py": [
12+
118,
13+
119,
14+
120
15+
],
16+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-papers100M/mlp_xrt.py": [
17+
127,
18+
128,
19+
129
20+
]
21+
}

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ public static Iterable<Class> getChecks() {
311311
PublicApiIsSecuritySensitiveCheck.class,
312312
PubliclyWritableDirectoriesCheck.class,
313313
PublicNetworkAccessToCloudResourcesCheck.class,
314+
PyTorchDataLoaderNumWorkersCheck.class,
314315
PytzTimeZoneInDatetimeConstructorCheck.class,
315316
RaiseOutsideExceptCheck.class,
316317
RandomSeedCheck.class,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.List;
23+
import org.sonar.check.Rule;
24+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
25+
import org.sonar.plugins.python.api.symbols.Symbol;
26+
import org.sonar.plugins.python.api.tree.Argument;
27+
import org.sonar.plugins.python.api.tree.CallExpression;
28+
import org.sonar.plugins.python.api.tree.Tree;
29+
import org.sonar.plugins.python.api.tree.UnpackingExpression;
30+
import org.sonar.python.tree.TreeUtils;
31+
32+
@Rule(key = "S6983")
33+
public class PyTorchDataLoaderNumWorkersCheck extends PythonSubscriptionCheck {
34+
private static final String TORCH_UTILS_DATA_DATA_LOADER = "torch.utils.data.DataLoader";
35+
public static final String MESSAGE = "Specify the `num_workers` parameter.";
36+
public static final String NUM_WORKERS_ARG_NAME = "num_workers";
37+
public static final int NUM_WORKERS_ARG_POSITION = 5;
38+
39+
@Override
40+
public void initialize(Context context) {
41+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
42+
CallExpression callExpression = (CallExpression) ctx.syntaxNode();
43+
Symbol calleeSymbol = callExpression.calleeSymbol();
44+
List<Argument> arguments = callExpression.arguments();
45+
if (calleeSymbol != null && TORCH_UTILS_DATA_DATA_LOADER.equals(calleeSymbol.fullyQualifiedName())
46+
&& isNumWorkersArgPresent(arguments)
47+
&& !isUnpackArgPresent(arguments)) {
48+
49+
ctx.addIssue(callExpression.callee(), MESSAGE);
50+
}
51+
});
52+
}
53+
54+
private static boolean isNumWorkersArgPresent(List<Argument> arguments) {
55+
return TreeUtils.nthArgumentOrKeyword(NUM_WORKERS_ARG_POSITION, NUM_WORKERS_ARG_NAME, arguments) == null;
56+
}
57+
58+
private static boolean isUnpackArgPresent(List<Argument> arguments) {
59+
return arguments.stream().anyMatch(UnpackingExpression.class::isInstance);
60+
}
61+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
<p>This rule raises an issue when a <code>torch.utils.data.DataLoader</code> is instantiated without specifying the <code>num_workers</code>
2+
parameter.</p>
3+
<h2>Why is this an issue?</h2>
4+
<p>In the PyTorch library, the data loaders are used to provide an interface where common operations such as batching can be implemented. It is also
5+
possible to parallelize the data loading process by using multiple worker processes. This can improve performance by increasing the number of batches
6+
being fetched in parallel, at the cost of higher memory usage. This performance increase can also be attributed to avoiding the Global Interpreter
7+
Lock (GIL) in the Python interpreter.</p>
8+
<h2>How to fix it</h2>
9+
<p>Specify the <code>num_workers</code> parameter when instantiating the <code>torch.utils.data.DataLoader</code> object.</p>
10+
<p>The default value of <code>0</code> will use the main process to load the data, and might be faster for small datasets that can fit completely in
11+
memory.</p>
12+
<p>For larger datasets, it is recommended to use a value of <code>1</code> or higher to parallelize the data loading process.</p>
13+
<h3>Code examples</h3>
14+
<h4>Noncompliant code example</h4>
15+
<pre data-diff-id="1" data-diff-type="noncompliant">
16+
from torch.utils.data import DataLoader
17+
from torchvision import datasets
18+
from torchvision.transforms import ToTensor
19+
20+
train_dataset = datasets.MNIST(root='data', train=True, transform=ToTensor())
21+
train_data_loader = DataLoader(train_dataset, batch_size=32)  # Noncompliant: the num_workers parameter is not specified
22+
</pre>
23+
<h4>Compliant solution</h4>
24+
<pre data-diff-id="1" data-diff-type="compliant">
25+
from torch.utils.data import DataLoader
26+
from torchvision import datasets
27+
from torchvision.transforms import ToTensor
28+
29+
train_dataset = datasets.MNIST(root='data', train=True, transform=ToTensor())
30+
train_data_loader = DataLoader(train_dataset, batch_size=32, num_workers=4)
31+
</pre>
32+
<h2>Resources</h2>
33+
<h3>Documentation</h3>
34+
<ul>
35+
<li> PyTorch documentation - <a href="https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading">Single- and Multi-process
36+
Data Loading</a> </li>
37+
<li> PyTorch documentation - <a href="https://pytorch.org/tutorials/beginner/basics/data_tutorial.html">Datasets and DataLoaders</a> </li>
38+
</ul>
39+
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"title": "The \"num_workers\" parameter should be specified for \"torch.utils.data.DataLoader\"",
3+
"type": "CODE_SMELL",
4+
"status": "ready",
5+
"remediation": {
6+
"func": "Constant\/Issue",
7+
"constantCost": "2min"
8+
},
9+
"tags": [
10+
"pytorch",
11+
"machine-learning"
12+
],
13+
"defaultSeverity": "Minor",
14+
"ruleSpecification": "RSPEC-6983",
15+
"sqKey": "S6983",
16+
"scope": "All",
17+
"quickfix": "targeted",
18+
"code": {
19+
"impacts": {
20+
"RELIABILITY": "LOW"
21+
},
22+
"attribute": "COMPLETE"
23+
}
24+
}

python-checks/src/main/resources/org/sonar/l10n/py/rules/python/Sonar_way_profile.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@
246246
"S6972",
247247
"S6973",
248248
"S6974",
249-
"S6979"
249+
"S6979",
250+
"S6983"
250251
]
251252
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
23+
import org.junit.jupiter.api.Test;
24+
import org.sonar.python.checks.utils.PythonCheckVerifier;
25+
26+
class PyTorchDataLoaderNumWorkersCheckTest {
27+
@Test
28+
void test() {
29+
PythonCheckVerifier.verify("src/test/resources/checks/pyTorchDataLoaderNumWorkersCheck.py", new PyTorchDataLoaderNumWorkersCheck());
30+
}
31+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from torch.utils.data import DataLoader
2+
from torch.utils.data import DataLoader as AliasedDataLoader
3+
import torch.utils.data
4+
import os
5+
6+
train_dataset = ...
7+
8+
noncomp = DataLoader(dataset=train_dataset, batch_size=32) # Noncompliant {{Specify the `num_workers` parameter.}}
9+
#^^^^^^^^^^
10+
11+
noncomp = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32) # Noncompliant
12+
#^^^^^^^^^^^^^^^^^^^^^^^^^^^
13+
14+
noncomp = AliasedDataLoader(dataset=train_dataset, batch_size=32) # Noncompliant
15+
#^^^^^^^^^^^^^^^^^
16+
17+
noncomp = DataLoader() # Noncompliant
18+
19+
comp1 = DataLoader(dataset=train_dataset, batch_size=32, num_workers=len(train_dataset) / os.cpu_count())
20+
comp2 = DataLoader(dataset=train_dataset, batch_size=32, num_workers=0)
21+
comp3 = DataLoader(dataset=train_dataset, batch_size=32, num_workers=1)
22+
comp4 = DataLoader(train_dataset, 32, False, False, False, 3) # the num_workers is the 6th arg, and in this case `3`
23+
comp5 = DataLoader(train_dataset, 32, False, False, False, 3, False)
24+
25+
dict = {"someStuff":4}
26+
comp5 = DataLoader(**dict)
27+
comp6 = DataLoader(dataset=train_dataset, **dict)
28+
comp7 = DataLoader(**{"someStuff": 3})
29+
30+
list = [1, 2, 3, 4, 5, 6]
31+
comp8 = DataLoader(*list)
32+
comp8 = DataLoader(dataset=train_dataset, *list)
33+
comp9 = DataLoader(*[1, 2, 3])
34+
35+
comp10 = DataLoader(dataset=train_dataset, num_workers=None)
36+
37+
class SubDataLoader(DataLoader):
38+
pass
39+
40+
# this should raise an issue but this is currently not supported
41+
comp10 = SubDataLoader()
42+
43+
# checks coverage for if the symbol is null
44+
(lambda x: x)(2)

0 commit comments

Comments
 (0)