
Commit 45b288f

Address security concerns in code
Signed-off-by: Keval Morabia <[email protected]>
1 parent d0b0c0f commit 45b288f

File tree: 15 files changed, +141 −42 lines changed

docs/source/guides/2_save_load.rst
Lines changed: 1 addition & 0 deletions

@@ -129,6 +129,7 @@ Here is the example workflow of restoring the ModelOpt-modified model architecture
     model = ...

     # Restore the model architecture using the saved `modelopt_state`
+    # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
     modelopt_state = torch.load("modelopt_state.pth", weights_only=False)
     model = mto.restore_from_modelopt_state(model, modelopt_state)
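The same pattern recurs throughout this commit: `weights_only=False` is tolerable only because the file being loaded was written by ModelOpt itself. For context, here is a minimal sketch of the guide's full round trip, assuming the `mto.modelopt_state()` API documented in this guide:

```python
import torch

import modelopt.torch.opt as mto

# Save side: persist the ModelOpt state produced for the modified model
model = ...  # a ModelOpt-modified model
torch.save(mto.modelopt_state(model), "modelopt_state.pth")

# Load side: this file was just written locally by ModelOpt, which is why
# weights_only=False is acceptable; never load an untrusted file this way
modelopt_state = torch.load("modelopt_state.pth", weights_only=False)
model = mto.restore_from_modelopt_state(model, modelopt_state)
```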

examples/llm_qat/export.py
Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@ def get_model(

     # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
     if hasattr(model, "peft_config"):
+        # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
         modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
         restore_from_modelopt_state(model, modelopt_state)
         print_rank_0("Restored modelopt state")

examples/llm_sparsity/README.md
Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ python data_prep.py --save_path data

 The following command demonstrates how to perform SAT on the Llama2-7B model on 8 GPUs.
 The model is finetuned on the [cnn_dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset for 3 epochs.
-The input data is tokenized to a maximum length of 1024 tokens. The tokenized data is saved as a pickle file for faster data loading. The one-time process takes less than an hour to finish depending on the CPU. The resulting pickle file can be utilized for future training sessions.
+The input data is tokenized to a maximum length of 1024 tokens.

 ```sh
 bash launch_finetune.sh --model meta-llama/Llama-2-7b-hf \

examples/llm_sparsity/finetune.py
Lines changed: 10 additions & 21 deletions

@@ -32,7 +32,6 @@
 import argparse
 import copy
 import os
-import pickle
 from collections.abc import Sequence
 from dataclasses import dataclass, field

@@ -232,27 +231,17 @@ def __init__(
     ):
         super().__init__()

-        pickle_name = f"dict_{split}_{tokenizer.model_max_length}.pickle"
         with training_args.main_process_first():
-            if os.path.isfile(pickle_name):
-                with open(pickle_name, "rb") as f:
-                    print_rank_0("Reuse pickled data")
-                    data_dict = pickle.load(f)
-            else:
-                print_rank_0("Loading data...")
-                list_data_dict = utils.jload(data_path)
-
-                print_rank_0("Formatting inputs...")
-                prompt_input = PROMPT_DICT["prompt_input"]
-                sources = [prompt_input.format_map(example) for example in list_data_dict]
-                targets = [
-                    f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict
-                ]
-
-                print_rank_0("Tokenizing inputs... This may take some time...")
-                data_dict = preprocess(sources, targets, tokenizer)
-                with open(pickle_name, "wb") as f:
-                    pickle.dump(data_dict, f, pickle.HIGHEST_PROTOCOL)
+            print_rank_0("Loading data...")
+            list_data_dict = utils.jload(data_path)
+
+            print_rank_0("Formatting inputs...")
+            prompt_input = PROMPT_DICT["prompt_input"]
+            sources = [prompt_input.format_map(example) for example in list_data_dict]
+            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+            print_rank_0("Tokenizing inputs... This may take some time...")
+            data_dict = preprocess(sources, targets, tokenizer)

         self.input_ids = data_dict["input_ids"]
         self.labels = data_dict["labels"]

modelopt/onnx/quantization/__main__.py
Lines changed: 22 additions & 4 deletions

@@ -52,6 +52,11 @@ def get_parser() -> argparse.ArgumentParser:
         type=str,
         help="Calibration data in npz/npy format. If None, random data for calibration will be used.",
     )
+    group.add_argument(
+        "--trust_calibration_data",
+        action="store_true",
+        help="If True, trust the calibration data and allow pickle deserialization.",
+    )
     group.add_argument(
         "--calibration_cache_path",
         type=str,

@@ -263,10 +268,23 @@ def main():
     args = get_parser().parse_args()
     calibration_data = None
     if args.calibration_data_path:
-        calibration_data = np.load(args.calibration_data_path, allow_pickle=True)
-        if args.calibration_data_path.endswith(".npz"):
-            # Convert the NpzFile object to a Python dictionary
-            calibration_data = {key: calibration_data[key] for key in calibration_data.files}
+        # Security: Disable pickle deserialization for untrusted sources to prevent RCE attacks
+        try:
+            calibration_data = np.load(
+                args.calibration_data_path, allow_pickle=args.trust_calibration_data
+            )
+            if args.calibration_data_path.endswith(".npz"):
+                # Convert the NpzFile object to a Python dictionary
+                calibration_data = {key: calibration_data[key] for key in calibration_data.files}
+        except ValueError as e:
+            if "allow_pickle" in str(e) and not args.trust_calibration_data:
+                raise ValueError(
+                    "Calibration data file contains pickled objects which pose a security risk. "
+                    "For trusted sources, you may enable pickle deserialization by setting the "
+                    "--trust_calibration_data flag."
+                ) from e
+            else:
+                raise

     quantize(
         args.onnx_path,
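The guard works because NumPy refuses to deserialize object arrays unless `allow_pickle=True`, and the `"allow_pickle" in str(e)` check matches NumPy's error message. A minimal sketch of that mechanism, with a hypothetical `calib.npz`:

```python
import numpy as np

# An object array can only be stored inside an .npz via pickle.
unsafe = np.array([{"input_ids": [1, 2, 3]}], dtype=object)
np.savez("calib.npz", data=unsafe)

try:
    # Without --trust_calibration_data the CLI loads with allow_pickle=False,
    # so accessing the pickled array raises a ValueError.
    np.load("calib.npz", allow_pickle=False)["data"]
except ValueError as e:
    # NumPy's message, "Object arrays cannot be loaded when allow_pickle=False",
    # contains the substring that main() checks for.
    assert "allow_pickle" in str(e)
```

Users with trusted calibration files can opt back in by passing `--trust_calibration_data` on the command line.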

modelopt/torch/export/distribute.py
Lines changed: 1 addition & 0 deletions

@@ -91,6 +91,7 @@ def read_configs_and_weights_from_rank(
             raise ValueError("NFSWorkspace is not initialized!")
         state_path = self._get_state_path(target_rank)
         if state_path.exists():
+            # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
             state = torch.load(state_path, map_location="cpu", weights_only=False)
             return state["config"], state["weight"]
         else:

modelopt/torch/opt/conversion.py
Lines changed: 1 addition & 0 deletions

@@ -526,6 +526,7 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]
         model = ...  # Create the model-like object

         # Restore the previously saved modelopt state followed by model weights
+        # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
         mto.restore_from_modelopt_state(
             model, torch.load("modelopt_state.pt", weights_only=False)
         )  # Restore modelopt state

modelopt/torch/opt/plugins/huggingface.py
Lines changed: 1 addition & 0 deletions

@@ -79,6 +79,7 @@ def new_init_fn(self, *args, **kwargs):
         modelopt_state_path = _get_modelopt_state_path(model_path)
         _original__init__(self, *args, **kwargs)
         if os.path.isfile(modelopt_state_path):
+            # Security NOTE: weights_only=False is used on ModelOpt-generated state_dict, not on untrusted user input
             modelopt_state = torch.load(modelopt_state_path, map_location="cpu", weights_only=False)
             with extra_context() if extra_context else nullcontext():
                 restore_from_modelopt_state(self, modelopt_state)

modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
Lines changed: 1 addition & 0 deletions

@@ -242,6 +242,7 @@ def restore_sharded_modelopt_state(
         return

     # Loading the common modelopt_state (replicated on all ranks)
+    # Security NOTE: weights_only=False is used here on NVIDIA-generated file, not on untrusted user input
     common_modelopt_state = torch.load(
         modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, weights_only=False
     )

modelopt/torch/opt/plugins/megatron.py
Lines changed: 74 additions & 13 deletions

@@ -15,9 +15,10 @@
 """Support quantization and save/resore for Megatron."""

 import contextlib
-import pickle  # nosec
+import io
 import types
 from typing import Any
+from warnings import warn

 import megatron.core.transformer.mlp as megatron_mlp
 import regex as re

@@ -26,6 +27,74 @@
 from ..dynamic import DynamicModule


+def _convert_dtypes_to_strings(obj: Any) -> Any:
+    """Convert torch.dtype to strings for JSON-safe serialization."""
+    if isinstance(obj, torch.dtype):
+        return {"__dtype__": str(obj)}
+    elif isinstance(obj, dict):
+        return {k: _convert_dtypes_to_strings(v) for k, v in obj.items()}
+    elif isinstance(obj, (list, tuple)):
+        converted = [_convert_dtypes_to_strings(item) for item in obj]
+        return {"__tuple__": converted} if isinstance(obj, tuple) else converted
+    return obj
+
+
+def _restore_dtypes_from_strings(obj: Any) -> Any:
+    """Restore torch.dtype from string representations."""
+    if isinstance(obj, dict):
+        if "__dtype__" in obj:
+            dtype_str = obj["__dtype__"].split(".")[-1]
+            return getattr(torch, dtype_str)
+        elif "__tuple__" in obj:
+            return tuple(_restore_dtypes_from_strings(item) for item in obj["__tuple__"])
+        return {k: _restore_dtypes_from_strings(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [_restore_dtypes_from_strings(item) for item in obj]
+    return obj
+
+
+def safe_serialize_state(extra_state: dict) -> torch.Tensor:
+    """Serialize extra_state safely without pickle.
+
+    Uses torch.save with weights_only=True for security.
+    Raises TypeError if extra_state contains unsafe types.
+    """
+    # Convert dtypes to strings for safe serialization
+    safe_state = _convert_dtypes_to_strings(extra_state)
+
+    # Serialize using PyTorch with new zipfile format
+    buffer = io.BytesIO()
+    torch.save(safe_state, buffer, _use_new_zipfile_serialization=True)
+
+    return torch.frombuffer(bytearray(buffer.getvalue()), dtype=torch.uint8)
+
+
+def safe_deserialize_state(state: torch.Tensor) -> dict:
+    """Deserialize extra_state safely without pickle.
+
+    Attempts new safe format first, falls back to pickle with warning for
+    backward compatibility with old checkpoints.
+    """
+    buffer = state.detach().cpu().numpy().tobytes()
+
+    try:
+        # Try new safe format (weights_only=True)
+        extra_state = torch.load(io.BytesIO(buffer), weights_only=True)
+        return _restore_dtypes_from_strings(extra_state)
+
+    except Exception:
+        # Fall back to pickle for old checkpoints
+        warn(
+            "Loading checkpoint in legacy pickle format. This poses a security risk (RCE). "
+            "Please re-save your checkpoint to use the new safe format.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        import pickle  # nosec - backward compatibility only
+
+        return pickle.loads(buffer)  # nosec
+
+
 def _modelopt_get_extra_state(self):
     """Populating the extra_state when state_dict() is called.

@@ -34,8 +103,7 @@ def _modelopt_get_extra_state(self):
     get_extra_state callbacks

     If there is no extra_state, None is returned. Otherwise, the dictionary
-    is serialized (via pickle) into a byte tensor following
-    TransformerEngine's approach. In this case, the extra_state,
+    is safely serialized into a byte tensor.
     """
     try:
         extra_state = super().get_extra_state()  # type: ignore[misc]

@@ -54,12 +122,9 @@ def _modelopt_get_extra_state(self):
     if len(extra_state) == 0:
         return None

-    # Serialize state into byte tensor
+    # Serialize state safely without pickle
     torch.cuda.synchronize()
-    state_serialized = bytearray(pickle.dumps(extra_state))  # nosec
-    state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
-
-    return state_serialized
+    return safe_serialize_state(extra_state)


 def _modelopt_set_extra_state(self, state: Any):

@@ -73,11 +138,7 @@ def _modelopt_set_extra_state(self, state: Any):
         return

     if isinstance(state, torch.Tensor):
-        # Default format: byte tensor with pickled data
-        #
-        # TODO: possible deserialization improvement
-        # https://github.com/NVIDIA/TensorRT-LLM/commits/main/tensorrt_llm/serialization.py
-        extra_state = pickle.loads(state.detach().cpu().numpy().tobytes())  # nosec
+        extra_state = safe_deserialize_state(state)
     else:
         raise RuntimeError("Unsupported extra_state format.")