Skip to content

Commit 765f00d

Browse files
dibussocericwb
andauthored
Limit B614 to torch.load deserializers (#1348)
* Limit B614 to torch.load deserializers Avoids false positives for torch.*.load helpers such as torch.utils.cpp_extension.load while preserving checks for torch.load and torch.serialization.load. Updated docstrings and example to reflect expected behavior. Resolves: #1343 * Update examples/pytorch_load.py --------- Co-authored-by: Eric Brown <ericwb@users.noreply.github.com>
1 parent 06fbbab commit 765f00d

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

bandit/plugins/pytorch_load.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
B614: Test for unsafe PyTorch load
77
==================================
88
9-
This plugin checks for unsafe use of `torch.load`. Using `torch.load` with
10-
untrusted data can lead to arbitrary code execution. There are two safe
11-
alternatives:
9+
This plugin checks for unsafe use of `torch.load` and
10+
`torch.serialization.load`. Using `torch.load` or
11+
`torch.serialization.load` with untrusted data can lead to arbitrary
12+
code execution. There are two safe alternatives:
1213
1314
1. Use `torch.load` with `weights_only=True` where only tensor data is
1415
extracted, and no arbitrary Python objects are deserialized
@@ -24,7 +25,7 @@
2425
2526
>> Issue: Use of unsafe PyTorch load
2627
Severity: Medium Confidence: High
27-
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
28+
CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html)
2829
Location: examples/pytorch_load_save.py:8
2930
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
3031
8 another_model.load_state_dict(torch.load('model_weights.pth',
@@ -34,7 +35,7 @@
3435
3536
.. seealso::
3637
37-
- https://cwe.mitre.org/data/definitions/94.html
38+
- https://cwe.mitre.org/data/definitions/502.html
3839
- https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
3940
- https://github.com/huggingface/safetensors
4041
@@ -50,23 +51,18 @@
5051
@test.test_id("B614")
5152
def pytorch_load(context):
5253
"""
53-
This plugin checks for unsafe use of `torch.load`. Using `torch.load`
54-
with untrusted data can lead to arbitrary code execution. The safe
55-
alternative is to use `weights_only=True` or the safetensors library.
54+
This plugin checks for unsafe use of `torch.load` and
55+
`torch.serialization.load`. Using `torch.load` or
56+
`torch.serialization.load` with untrusted data can lead to
57+
arbitrary code execution. The safe alternative is to use
58+
`weights_only=True` or the safetensors library.
5659
"""
5760
imported = context.is_module_imported_exact("torch")
5861
qualname = context.call_function_name_qual
5962
if not imported and isinstance(qualname, str):
6063
return
6164

62-
qualname_list = qualname.split(".")
63-
func = qualname_list[-1]
64-
if all(
65-
[
66-
"torch" in qualname_list,
67-
func == "load",
68-
]
69-
):
65+
if qualname in {"torch.load", "torch.serialization.load"}:
7066
# For torch.load, check if weights_only=True is specified
7167
weights_only = context.get_call_arg_value("weights_only")
7268
if weights_only == "True" or weights_only is True:

examples/pytorch_load.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@
2424
# Example of loading with both map_location and weights_only=True (should NOT trigger B614)
2525
safe_cpu_model = models.resnet18()
2626
safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True))
27+
28+
# Example of a torch.*.load call that should NOT trigger B614
29+
# Only pickle deserializers should trigger B614
30+
torch.utils.cpp_extension.load(name="example_ext", sources=[])

0 commit comments

Comments
 (0)