Skip to content

Commit 41749af

Browse files
authored
SONARPY-1914 Usage of "torch.load" can lead to untrusted code execution (#1956)
* SONARPY-1914 update metadata for S6985 * SONARPY-1914 implement S6985: Usage of "torch.load" can lead to untrusted code execution * SONARPY-1914 add S6985 to checklist * SONARPY-1914 add expected ruling issues * SONARPY-1914 implement to not to raise when weights_only is not false * SONARPY-1914 fix small issues of the PR * SONARPY-1914 Find variable definition and check if False When a variable is passed as the weights_only parameter, the `ReachingDefinitionsAnalysis` class is used to check if the variable is set to true * SONARPY-1914 add unit tests for TreeUtils.toStreamInstanceOfMapper
1 parent 5192d90 commit 41749af

File tree

10 files changed

+256
-3
lines changed

10 files changed

+256
-3
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"project:pecos/examples/MACLR/dataset.py": [
3+
116
4+
],
5+
"project:pecos/examples/MACLR/evaluate.py": [
6+
92
7+
],
8+
"project:pecos/examples/MACLR/main.py": [
9+
280
10+
],
11+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-arxiv/mlp.py": [
12+
109
13+
],
14+
"project:pecos/examples/giant-xrt/OGB_baselines/ogbn-papers100M/mlp_sgc.py": [
15+
109
16+
],
17+
"project:pecos/pecos/xmc/xtransformer/matcher.py": [
18+
411,
19+
1323,
20+
1338
21+
],
22+
"project:pecos/pecos/xmc/xtransformer/module.py": [
23+
460
24+
]
25+
}

python-checks/src/main/java/org/sonar/python/checks/CheckList.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ public static Iterable<Class> getChecks() {
366366
TooManyParametersCheck.class,
367367
TooManyReturnsCheck.class,
368368
TorchAutogradVariableShouldNotBeUsedCheck.class,
369+
TorchLoadLeadsToUntrustedCodeExecutionCheck.class,
369370
TrailingCommentCheck.class,
370371
TrailingWhitespaceCheck.class,
371372
TypeAliasAnnotationCheck.class,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import java.util.List;
23+
import java.util.Set;
24+
import org.sonar.check.Rule;
25+
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
26+
import org.sonar.plugins.python.api.symbols.Symbol;
27+
import org.sonar.plugins.python.api.tree.Argument;
28+
import org.sonar.plugins.python.api.tree.CallExpression;
29+
import org.sonar.plugins.python.api.tree.Expression;
30+
import org.sonar.plugins.python.api.tree.Name;
31+
import org.sonar.plugins.python.api.tree.RegularArgument;
32+
import org.sonar.plugins.python.api.tree.Tree;
33+
import org.sonar.python.cfg.fixpoint.ReachingDefinitionsAnalysis;
34+
import org.sonar.python.tree.TreeUtils;
35+
36+
@Rule(key = "S6985")
37+
public class TorchLoadLeadsToUntrustedCodeExecutionCheck extends PythonSubscriptionCheck {
38+
39+
public static final String TORCH_LOAD = "torch.load";
40+
public static final String MESSAGE = "Replace this call with a safe alternative.";
41+
public static final String PYTHON_FALSE = "False";
42+
public static final String WEIGHTS_ONLY = "weights_only";
43+
44+
private ReachingDefinitionsAnalysis reachingDefinitionsAnalysis;
45+
46+
@Override
47+
public void initialize(Context context) {
48+
context.registerSyntaxNodeConsumer(Tree.Kind.FILE_INPUT, ctx -> reachingDefinitionsAnalysis =
49+
new ReachingDefinitionsAnalysis(ctx.pythonFile()));
50+
51+
context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, ctx -> {
52+
CallExpression callExpression = (CallExpression) ctx.syntaxNode();
53+
Symbol calleeSymbol = callExpression.calleeSymbol();
54+
if (calleeSymbol != null && TORCH_LOAD.equals(calleeSymbol.fullyQualifiedName()) && isWeightsOnlyNotFoundOrSetToFalse(callExpression.arguments())) {
55+
ctx.addIssue(callExpression.callee(), MESSAGE);
56+
}
57+
});
58+
}
59+
60+
private boolean isWeightsOnlyNotFoundOrSetToFalse(List<Argument> arguments) {
61+
RegularArgument weightsOnlyArg = TreeUtils.argumentByKeyword(WEIGHTS_ONLY, arguments);
62+
if (weightsOnlyArg == null) return true;
63+
if (weightsOnlyArg.expression() instanceof Name name) {
64+
return PYTHON_FALSE.equals(name.name()) || isNameSetToFalse(name);
65+
}
66+
return false;
67+
}
68+
69+
private boolean isNameSetToFalse(Name name) {
70+
Set<Expression> values = reachingDefinitionsAnalysis.valuesAtLocation(name);
71+
return values.size() == 1 && values.stream()
72+
.flatMap(TreeUtils.toStreamInstanceOfMapper(Name.class))
73+
.map(Name::name).allMatch(PYTHON_FALSE::equals);
74+
}
75+
76+
77+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
<p>This rule raises an issue when <code>pytorch.load</code> is used to load a model.</p>
2+
<h2>Why is this an issue?</h2>
3+
<p>In PyTorch, it is common to load serialized models using the <code>torch.load</code> function. Under the hood, <code>torch.load</code> uses the
4+
<code>pickle</code> library to load the model and the weights. If the model comes from an untrusted source, an attacker could inject a malicious
5+
payload which would be executed during the deserialization.</p>
6+
<h2>How to fix it</h2>
7+
<p>Use a safer alternative to load the model, such as <code>safetensors.torch.load_model</code>.</p>
8+
<h3>Code examples</h3>
9+
<h4>Noncompliant code example</h4>
10+
<pre data-diff-id="1" data-diff-type="noncompliant">
11+
import torch
12+
13+
model = torch.load('model.pth') # Noncompliant: torch.load is used to load the model
14+
</pre>
15+
<h4>Compliant solution</h4>
16+
<pre data-diff-id="1" data-diff-type="compliant">
17+
import torch
18+
import safetensors
19+
20+
model = MyModel()
21+
safetensors.torch.load_model(model, 'model.pth')
22+
</pre>
23+
<h2>Resources</h2>
24+
<h3>Documentation</h3>
25+
<ul>
26+
<li> Pytorch documentation: <a href="https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model">Save/Load Entire
27+
Model</a> </li>
28+
</ul>
29+
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"title": "Usage of \"torch.load\" can lead to untrusted code execution",
3+
"type": "SECURITY_HOTSPOT",
4+
"status": "ready",
5+
"remediation": {
6+
"func": "Constant\/Issue",
7+
"constantCost": "15min"
8+
},
9+
"tags": [
10+
"pytorch",
11+
"machine-learning"
12+
],
13+
"defaultSeverity": "Major",
14+
"ruleSpecification": "RSPEC-6985",
15+
"sqKey": "S6985",
16+
"scope": "All",
17+
"quickfix": "infeasible",
18+
"code": {
19+
"impacts": {
20+
"SECURITY": "HIGH"
21+
},
22+
"attribute": "CONVENTIONAL"
23+
}
24+
}

python-checks/src/main/resources/org/sonar/l10n/py/rules/python/Sonar_way_profile.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@
247247
"S6973",
248248
"S6974",
249249
"S6979",
250-
"S6983"
250+
"S6983",
251+
"S6985"
251252
]
252253
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* SonarQube Python Plugin
3+
* Copyright (C) 2011-2024 SonarSource SA
4+
* mailto:info AT sonarsource DOT com
5+
*
6+
* This program is free software; you can redistribute it and/or
7+
* modify it under the terms of the GNU Lesser General Public
8+
* License as published by the Free Software Foundation; either
9+
* version 3 of the License, or (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14+
* Lesser General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU Lesser General Public License
17+
* along with this program; if not, write to the Free Software Foundation,
18+
* Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19+
*/
20+
package org.sonar.python.checks;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.sonar.python.checks.utils.PythonCheckVerifier;
24+
25+
class TorchLoadLeadsToUntrustedCodeExecutionCheckTest {
26+
27+
@Test
28+
void test() {
29+
PythonCheckVerifier.verify("src/test/resources/checks/torchLoadLeadsToUntrustedCodeExecution.py", new TorchLoadLeadsToUntrustedCodeExecutionCheck());
30+
}
31+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
from torch import load
3+
import safetensors
4+
5+
model = torch.load('model.pth') # Noncompliant {{Replace this call with a safe alternative.}}
6+
#^^^^^^^^^^
7+
8+
some_path = ...
9+
model = load(some_path) # Noncompliant
10+
11+
torch.load(model2, 'model.pth', weights_only=False) #Noncompliant
12+
13+
model2 = ...
14+
safetensors.torch.load_model(model2, 'model.pth')
15+
16+
torch.load(model2, 'model.pth', weights_only=True)
17+
18+
def unknown_weights_only_value(some_value, some_func):
19+
torch.load(model2, 'model.pth', weights_only=some_value)
20+
torch.load(model2, 'model.pth', weights_only=some_func())
21+
22+
def conditional_weights_only(cond):
23+
weights_only = True
24+
if cond:
25+
weights_only = False
26+
27+
torch.load(model2, 'model.pth', weights_only=weights_only)
28+
29+
def only_one_definition():
30+
weights_only = True
31+
torch.load(model2, 'model.pth', weights_only=weights_only)
32+
33+
weights_only = False
34+
torch.load(model2, 'model.pth', weights_only=weights_only) #Noncompliant
35+
36+
# test if no issue is raised if there is no symbol for the callee
37+
something[42]()

python-frontend/src/main/java/org/sonar/python/tree/TreeUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ public static <T extends Tree> Optional<T> toOptionalInstanceOf(Class<T> castToC
461461
return Optional.ofNullable(tree).filter(castToClass::isInstance).map(castToClass::cast);
462462
}
463463

464+
public static <T extends Tree> Function<Tree, Stream<T>> toStreamInstanceOfMapper(Class<T> castToClass) {
465+
return tree -> toOptionalInstanceOf(castToClass, tree).map(Stream::of).orElse(Stream.empty());
466+
}
467+
464468
public static Optional<Tree> firstChild(Tree tree, Predicate<Tree> filter) {
465469
if (filter.test(tree)) {
466470
return Optional.of(tree);

python-frontend/src/test/java/org/sonar/python/tree/TreeUtilsTest.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
import java.util.Collections;
2626
import java.util.List;
2727
import java.util.Optional;
28-
import java.util.function.Function;
29-
import java.util.stream.Collectors;
28+
import java.util.stream.Stream;
3029
import org.junit.jupiter.api.Test;
3130
import org.sonar.plugins.python.api.symbols.Symbol;
3231
import org.sonar.plugins.python.api.tree.AnyParameter;
@@ -544,6 +543,31 @@ void test_toInstanceOfMapper() {
544543
assertThat(funcDefPresent).isFalse();
545544
}
546545

546+
@Test
547+
void test_toStreamInstanceOfMapper() {
548+
var fileInput = PythonTestUtils.parse(
549+
"class A:",
550+
" x = True",
551+
" def foo(self):",
552+
" def foo2(x, y): return x + y",
553+
" return foo2(1, 1)",
554+
" class B:",
555+
" def bar(self): pass"
556+
);
557+
Tree tree = PythonTestUtils.getFirstChild(fileInput, t -> t.is(Kind.CLASSDEF));
558+
559+
boolean classPresent = Stream.of(tree)
560+
.flatMap(TreeUtils.toStreamInstanceOfMapper(ClassDef.class))
561+
.count() > 0;
562+
563+
assertThat(classPresent).isTrue();
564+
565+
boolean funcDefPresent = Stream.of(tree)
566+
.flatMap(TreeUtils.toStreamInstanceOfMapper(FunctionDef.class))
567+
.count() > 0;
568+
569+
assertThat(funcDefPresent).isFalse();
570+
}
547571
@Test
548572
void test_findIndentationSize() {
549573
var fileInput = PythonTestUtils.parse("def foo():\n" +

0 commit comments

Comments
 (0)