|
6 | 6 | B614: Test for unsafe PyTorch load |
7 | 7 | ================================== |
8 | 8 |
|
9 | | -This plugin checks for unsafe use of `torch.load`. Using `torch.load` with |
10 | | -untrusted data can lead to arbitrary code execution. There are two safe |
11 | | -alternatives: |
| 9 | +This plugin checks for unsafe use of `torch.load` and |
| 10 | +`torch.serialization.load`. Using `torch.load` or |
| 11 | +`torch.serialization.load` with untrusted data can lead to arbitrary |
| 12 | +code execution. There are two safe alternatives: |
12 | 13 |
|
13 | 14 | 1. Use `torch.load` with `weights_only=True` where only tensor data is |
14 | 15 | extracted, and no arbitrary Python objects are deserialized |
|
24 | 25 |
|
25 | 26 | >> Issue: Use of unsafe PyTorch load |
26 | 27 | Severity: Medium Confidence: High |
27 | | - CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html) |
| 28 | + CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html) |
28 | 29 | Location: examples/pytorch_load_save.py:8 |
29 | 30 | 7 loaded_model.load_state_dict(torch.load('model_weights.pth')) |
30 | 31 | 8 another_model.load_state_dict(torch.load('model_weights.pth', |
|
34 | 35 |
|
35 | 36 | .. seealso:: |
36 | 37 |
|
37 | | - - https://cwe.mitre.org/data/definitions/94.html |
| 38 | + - https://cwe.mitre.org/data/definitions/502.html |
38 | 39 | - https://pytorch.org/docs/stable/generated/torch.load.html#torch.load |
39 | 40 | - https://github.com/huggingface/safetensors |
40 | 41 |
|
|
50 | 51 | @test.test_id("B614") |
51 | 52 | def pytorch_load(context): |
52 | 53 | """ |
53 | | - This plugin checks for unsafe use of `torch.load`. Using `torch.load` |
54 | | - with untrusted data can lead to arbitrary code execution. The safe |
55 | | - alternative is to use `weights_only=True` or the safetensors library. |
| 54 | + This plugin checks for unsafe use of `torch.load` and |
| 55 | + `torch.serialization.load`. Using `torch.load` or |
| 56 | + `torch.serialization.load` with untrusted data can lead to |
| 57 | + arbitrary code execution. The safe alternative is to use |
| 58 | + `weights_only=True` or the safetensors library. |
56 | 59 | """ |
57 | 60 | imported = context.is_module_imported_exact("torch") |
58 | 61 | qualname = context.call_function_name_qual |
59 | 62 | if not imported and isinstance(qualname, str): |
60 | 63 | return |
61 | 64 |
|
62 | | - qualname_list = qualname.split(".") |
63 | | - func = qualname_list[-1] |
64 | | - if all( |
65 | | - [ |
66 | | - "torch" in qualname_list, |
67 | | - func == "load", |
68 | | - ] |
69 | | - ): |
| 65 | + if qualname in {"torch.load", "torch.serialization.load"}: |
70 | 66 | # For torch.load, check if weights_only=True is specified |
71 | 67 | weights_only = context.get_call_arg_value("weights_only") |
72 | 68 | if weights_only == "True" or weights_only is True: |
|
0 commit comments