Commit 1c03c81

Address security concerns in code
Signed-off-by: Keval Morabia <[email protected]>
1 parent d0b0c0f commit 1c03c81

File tree

15 files changed: +67 / -330 lines changed

docs/source/guides/2_save_load.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -129,6 +129,7 @@ Here is the example workflow of restoring the ModelOpt-modified model architectu
 model = ...

 # Restore the model architecture using the saved `modelopt_state`
+# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
 modelopt_state = torch.load("modelopt_state.pth", weights_only=False)
 model = mto.restore_from_modelopt_state(model, modelopt_state)
```
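For context, the workflow this note annotates looks roughly as follows. This is a minimal sketch, assuming `mto` is `modelopt.torch.opt` and that `modelopt_state.pth` was produced by the same pipeline; the file names follow the docs above:

```python
import torch
import modelopt.torch.opt as mto

# After quantizing/sparsifying with ModelOpt, persist the ModelOpt state
# and the weights separately (sketch: `model` is a ModelOpt-modified module).
torch.save(mto.modelopt_state(model), "modelopt_state.pth")
torch.save(model.state_dict(), "model_weights.pth")

# Later: re-create the original architecture, then restore the ModelOpt modifications.
model = ...  # rebuild the unmodified base model here

# Security NOTE: weights_only=False is acceptable because this file was
# generated by ModelOpt above, not received from an untrusted source.
modelopt_state = torch.load("modelopt_state.pth", weights_only=False)
model = mto.restore_from_modelopt_state(model, modelopt_state)
model.load_state_dict(torch.load("model_weights.pth", weights_only=True))
```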

examples/llm_qat/export.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -51,6 +51,7 @@ def get_model(

     # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
     if hasattr(model, "peft_config"):
+        # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
         modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
         restore_from_modelopt_state(model, modelopt_state)
         print_rank_0("Restored modelopt state")
```
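Why `weights_only=False` is needed at all: a ModelOpt state dict embeds custom Python objects, and PyTorch's weights-only unpickler only admits tensors and a small allowlist of builtin types. A self-contained illustration (the `Config` class is hypothetical, standing in for ModelOpt's internal state objects):

```python
import torch

class Config:
    # Hypothetical stand-in for the custom objects a ModelOpt state dict contains.
    def __init__(self, bits: int = 4):
        self.bits = bits

torch.save({"quant_cfg": Config()}, "/tmp/state_with_objects.pth")

try:
    # The weights-only unpickler rejects non-allowlisted classes such as Config.
    torch.load("/tmp/state_with_objects.pth", weights_only=True)
except Exception as err:  # pickle.UnpicklingError on recent PyTorch
    print(f"weights_only=True rejected the file: {err}")

# The same self-generated (trusted) file loads fine with weights_only=False.
state = torch.load("/tmp/state_with_objects.pth", weights_only=False)
print(state["quant_cfg"].bits)  # -> 4
```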

examples/llm_sparsity/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -84,7 +84,7 @@ python data_prep.py --save_path data

 The following command demonstrates how to perform SAT on the Llama2-7B model on 8 GPUs.
 The model is finetuned on the [cnn_dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset for 3 epochs.
-The input data is tokenized to a maximum length of 1024 tokens. The tokenized data is saved as a pickle file for faster data loading. The one-time process takes less than an hour to finish depending on the CPU. The resulting pickle file can be utilized for future training sessions.
+The input data is tokenized to a maximum length of 1024 tokens.

 ```sh
 bash launch_finetune.sh --model meta-llama/Llama-2-7b-hf \
````

examples/llm_sparsity/finetune.py

Lines changed: 10 additions & 21 deletions
```diff
@@ -32,7 +32,6 @@
 import argparse
 import copy
 import os
-import pickle
 from collections.abc import Sequence
 from dataclasses import dataclass, field

@@ -232,27 +231,17 @@ def __init__(
     ):
         super().__init__()

-        pickle_name = f"dict_{split}_{tokenizer.model_max_length}.pickle"
         with training_args.main_process_first():
-            if os.path.isfile(pickle_name):
-                with open(pickle_name, "rb") as f:
-                    print_rank_0("Reuse pickled data")
-                    data_dict = pickle.load(f)
-            else:
-                print_rank_0("Loading data...")
-                list_data_dict = utils.jload(data_path)
-
-                print_rank_0("Formatting inputs...")
-                prompt_input = PROMPT_DICT["prompt_input"]
-                sources = [prompt_input.format_map(example) for example in list_data_dict]
-                targets = [
-                    f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict
-                ]
-
-                print_rank_0("Tokenizing inputs... This may take some time...")
-                data_dict = preprocess(sources, targets, tokenizer)
-                with open(pickle_name, "wb") as f:
-                    pickle.dump(data_dict, f, pickle.HIGHEST_PROTOCOL)
+            print_rank_0("Loading data...")
+            list_data_dict = utils.jload(data_path)
+
+            print_rank_0("Formatting inputs...")
+            prompt_input = PROMPT_DICT["prompt_input"]
+            sources = [prompt_input.format_map(example) for example in list_data_dict]
+            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+            print_rank_0("Tokenizing inputs... This may take some time...")
+            data_dict = preprocess(sources, targets, tokenizer)

         self.input_ids = data_dict["input_ids"]
         self.labels = data_dict["labels"]
```
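The deleted caching path is the textbook unsafe-deserialization pattern: `pickle.load` executes whatever a crafted file instructs it to, so a shared or tampered cache file means code execution. A minimal demonstration of the mechanism (the `Evil` class is illustrative, not from this repo):

```python
import pickle

class Evil:
    # pickle calls the callable returned by __reduce__ when *loading*;
    # print stands in here for os.system or anything else an attacker picks.
    def __reduce__(self):
        return (print, ("arbitrary code ran during unpickling!",))

payload = pickle.dumps(Evil())
pickle.loads(payload)  # prints the message; a crafted cache file could run anything
```

Dropping the cache and re-tokenizing via `utils.jload` plus `preprocess` on every run trades some startup time for removing that attack surface entirely.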

modelopt/onnx/quantization/__main__.py

Lines changed: 22 additions & 4 deletions
```diff
@@ -52,6 +52,11 @@ def get_parser() -> argparse.ArgumentParser:
         type=str,
         help="Calibration data in npz/npy format. If None, random data for calibration will be used.",
     )
+    group.add_argument(
+        "--trust_calibration_data",
+        action="store_true",
+        help="If True, trust the calibration data and allow pickle deserialization.",
+    )
     group.add_argument(
         "--calibration_cache_path",
         type=str,

@@ -263,10 +268,23 @@ def main():
     args = get_parser().parse_args()
     calibration_data = None
     if args.calibration_data_path:
-        calibration_data = np.load(args.calibration_data_path, allow_pickle=True)
-        if args.calibration_data_path.endswith(".npz"):
-            # Convert the NpzFile object to a Python dictionary
-            calibration_data = {key: calibration_data[key] for key in calibration_data.files}
+        # Security: Disable pickle deserialization for untrusted sources to prevent RCE attacks
+        try:
+            calibration_data = np.load(
+                args.calibration_data_path, allow_pickle=args.trust_calibration_data
+            )
+            if args.calibration_data_path.endswith(".npz"):
+                # Convert the NpzFile object to a Python dictionary
+                calibration_data = {key: calibration_data[key] for key in calibration_data.files}
+        except ValueError as e:
+            if "allow_pickle" in str(e) and not args.trust_calibration_data:
+                raise ValueError(
+                    "Calibration data file contains pickled objects which pose a security risk. "
+                    "For trusted sources, you may enable pickle deserialization by setting the "
+                    "--trust_calibration_data flag."
+                ) from e
+            else:
+                raise

     quantize(
         args.onnx_path,
```
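The new flag maps directly onto NumPy's behavior: `np.load(..., allow_pickle=False)` loads plain numeric arrays fine but raises `ValueError` for object arrays, which can only be deserialized through pickle. A small sketch of both cases (file names hypothetical):

```python
import numpy as np

# Plain numeric data never needs pickle and still loads with allow_pickle=False.
np.savez("calib_safe.npz", input_ids=np.zeros((8, 128), dtype=np.int64))
safe = np.load("calib_safe.npz", allow_pickle=False)  # works

# Object-dtype arrays are serialized via pickle, so loading them is now opt-in.
np.save("calib_pickled.npy", np.array([{"batch": 0}], dtype=object))
try:
    np.load("calib_pickled.npy", allow_pickle=False)
except ValueError as err:
    print(err)  # "Object arrays cannot be loaded when allow_pickle=False"
```

For a trusted pickled file, the opt-in would look like `python -m modelopt.onnx.quantization --onnx_path model.onnx --calibration_data_path calib_pickled.npy --trust_calibration_data`, assuming `--onnx_path` matches the parser above; `model.onnx` is a placeholder.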
