1 change: 1 addition & 0 deletions docs/source/guides/2_save_load.rst
@@ -129,6 +129,7 @@ Here is the example workflow of restoring the ModelOpt-modified model architecture
model = ...
# Restore the model architecture using the saved `modelopt_state`
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
modelopt_state = torch.load("modelopt_state.pth", weights_only=False)
model = mto.restore_from_modelopt_state(model, modelopt_state)
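For readers without the surrounding guide open, the full round trip this hunk belongs to looks roughly like the sketch below. The toy model, the INT8 default config, and the file names are illustrative assumptions; only the `mto.modelopt_state` / `mto.restore_from_modelopt_state` calls mirror the documented workflow.

```python
import torch
import torch.nn as nn

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq

# Build and quantize a toy model (model and config are illustrative).
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

def forward_loop(m):
    # A few random batches are enough to calibrate this toy example.
    for _ in range(8):
        m(torch.randn(4, 16))

model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# Save the ModelOpt state and the weights separately.
torch.save(mto.modelopt_state(model), "modelopt_state.pth")
torch.save(model.state_dict(), "model_weights.pth")

# Later: re-create the plain model, restore the ModelOpt-modified architecture,
# then load the weights. weights_only=False is acceptable because the file is
# ModelOpt-generated, not untrusted user input.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
modelopt_state = torch.load("modelopt_state.pth", weights_only=False)
model = mto.restore_from_modelopt_state(model, modelopt_state)
model.load_state_dict(torch.load("model_weights.pth", weights_only=True))
```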
1 change: 1 addition & 0 deletions examples/llm_qat/export.py
@@ -51,6 +51,7 @@ def get_model(

# Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
if hasattr(model, "peft_config"):
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
restore_from_modelopt_state(model, modelopt_state)
print_rank_0("Restored modelopt state")
2 changes: 1 addition & 1 deletion examples/llm_sparsity/README.md
@@ -84,7 +84,7 @@ python data_prep.py --save_path data

The following command demonstrates how to perform SAT on the Llama2-7B model on 8 GPUs.
The model is finetuned on the [cnn_dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset for 3 epochs.
The input data is tokenized to a maximum length of 1024 tokens. The tokenized data is saved as a pickle file for faster data loading. The one-time process takes less than an hour to finish depending on the CPU. The resulting pickle file can be utilized for future training sessions.
The input data is tokenized to a maximum length of 1024 tokens.

```sh
bash launch_finetune.sh --model meta-llama/Llama-2-7b-hf \
31 changes: 10 additions & 21 deletions examples/llm_sparsity/finetune.py
@@ -32,7 +32,6 @@
import argparse
import copy
import os
import pickle
from collections.abc import Sequence
from dataclasses import dataclass, field

@@ -232,27 +231,17 @@ def __init__(
):
super().__init__()

pickle_name = f"dict_{split}_{tokenizer.model_max_length}.pickle"
with training_args.main_process_first():
if os.path.isfile(pickle_name):
with open(pickle_name, "rb") as f:
print_rank_0("Reuse pickled data")
data_dict = pickle.load(f)
else:
print_rank_0("Loading data...")
list_data_dict = utils.jload(data_path)

print_rank_0("Formatting inputs...")
prompt_input = PROMPT_DICT["prompt_input"]
sources = [prompt_input.format_map(example) for example in list_data_dict]
targets = [
f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict
]

print_rank_0("Tokenizing inputs... This may take some time...")
data_dict = preprocess(sources, targets, tokenizer)
with open(pickle_name, "wb") as f:
pickle.dump(data_dict, f, pickle.HIGHEST_PROTOCOL)
print_rank_0("Loading data...")
list_data_dict = utils.jload(data_path)

print_rank_0("Formatting inputs...")
prompt_input = PROMPT_DICT["prompt_input"]
sources = [prompt_input.format_map(example) for example in list_data_dict]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

print_rank_0("Tokenizing inputs... This may take some time...")
data_dict = preprocess(sources, targets, tokenizer)

self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
26 changes: 22 additions & 4 deletions modelopt/onnx/quantization/__main__.py
@@ -52,6 +52,11 @@ def get_parser() -> argparse.ArgumentParser:
type=str,
help="Calibration data in npz/npy format. If None, random data for calibration will be used.",
)
group.add_argument(
"--trust_calibration_data",
action="store_true",
help="If True, trust the calibration data and allow pickle deserialization.",
)
group.add_argument(
"--calibration_cache_path",
type=str,
@@ -263,10 +268,23 @@ def main():
args = get_parser().parse_args()
calibration_data = None
if args.calibration_data_path:
calibration_data = np.load(args.calibration_data_path, allow_pickle=True)
if args.calibration_data_path.endswith(".npz"):
# Convert the NpzFile object to a Python dictionary
calibration_data = {key: calibration_data[key] for key in calibration_data.files}
# Security: Disable pickle deserialization for untrusted sources to prevent RCE attacks
try:
calibration_data = np.load(
args.calibration_data_path, allow_pickle=args.trust_calibration_data
)
if args.calibration_data_path.endswith(".npz"):
# Convert the NpzFile object to a Python dictionary
calibration_data = {key: calibration_data[key] for key in calibration_data.files}
except ValueError as e:
if "allow_pickle" in str(e) and not args.trust_calibration_data:
raise ValueError(
"Calibration data file contains pickled objects which pose a security risk. "
"For trusted sources, you may enable pickle deserialization by setting the "
"--trust_calibration_data flag."
) from e
else:
raise

quantize(
args.onnx_path,
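The error path that the new `--trust_calibration_data` flag gates can be reproduced in isolation. The snippet below is a standalone sketch with made-up file names; it only demonstrates the NumPy behavior that `main()` now intercepts.

```python
import numpy as np

# A plain array archive loads fine with pickle support disabled.
np.savez("calib.npz", input_ids=np.zeros((4, 128), dtype=np.int64))
data = np.load("calib.npz", allow_pickle=False)
print({key: data[key].shape for key in data.files})

# An archive holding Python objects needs pickle; accessing it with
# allow_pickle=False raises the ValueError that main() now intercepts.
np.savez("calib_obj.npz", meta=np.array({"note": "object payload"}, dtype=object))
try:
    np.load("calib_obj.npz", allow_pickle=False)["meta"]
except ValueError as err:
    print("blocked:", err)
```

When the calibration file genuinely needs pickled objects and comes from a trusted source, the equivalent CLI run would simply add `--trust_calibration_data` next to `--calibration_data_path`.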
1 change: 1 addition & 0 deletions modelopt/torch/export/distribute.py
@@ -91,6 +91,7 @@ def read_configs_and_weights_from_rank(
raise ValueError("NFSWorkspace is not initialized!")
state_path = self._get_state_path(target_rank)
if state_path.exists():
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
state = torch.load(state_path, map_location="cpu", weights_only=False)
return state["config"], state["weight"]
else:
1 change: 1 addition & 0 deletions modelopt/torch/opt/conversion.py
@@ -526,6 +526,7 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]
model = ... # Create the model-like object

# Restore the previously saved modelopt state followed by model weights
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
mto.restore_from_modelopt_state(
model, torch.load("modelopt_state.pt", weights_only=False)
) # Restore modelopt state
1 change: 1 addition & 0 deletions modelopt/torch/opt/plugins/huggingface.py
@@ -79,6 +79,7 @@ def new_init_fn(self, *args, **kwargs):
modelopt_state_path = _get_modelopt_state_path(model_path)
_original__init__(self, *args, **kwargs)
if os.path.isfile(modelopt_state_path):
# Security NOTE: weights_only=False is used on ModelOpt-generated state_dict, not on untrusted user input
modelopt_state = torch.load(modelopt_state_path, map_location="cpu", weights_only=False)
with extra_context() if extra_context else nullcontext():
restore_from_modelopt_state(self, modelopt_state)
1 change: 1 addition & 0 deletions modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
@@ -242,6 +242,7 @@ def restore_sharded_modelopt_state(
return

# Loading the common modelopt_state (replicated on all ranks)
# Security NOTE: weights_only=False is used here on NVIDIA-generated file, not on untrusted user input
common_modelopt_state = torch.load(
modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, weights_only=False
)
2 changes: 1 addition & 1 deletion modelopt/torch/opt/plugins/megatron.py
@@ -76,7 +76,7 @@ def _modelopt_set_extra_state(self, state: Any):
# Default format: byte tensor with pickled data
#
# TODO: possible deserialization improvement
# https://github.com/NVIDIA/TensorRT-LLM/commits/main/tensorrt_llm/serialization.py
# https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serialization.py
extra_state = pickle.loads(state.detach().cpu().numpy().tobytes()) # nosec
else:
raise RuntimeError("Unsupported extra_state format.")
2 changes: 2 additions & 0 deletions modelopt/torch/opt/plugins/peft.py
@@ -72,6 +72,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs):
assert adapter_name in self.peft_config, (
f"ModelOpt modified model should have adapter_name={adapter_name} in peft_config"
)
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
restore_from_modelopt_state(
self, torch.load(modelopt_state_path, map_location="cpu", weights_only=False)
)
@@ -85,6 +86,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs):
if os.path.isfile(_get_quantizer_state_save_path(model_id)):
from modelopt.torch.quantization.nn import TensorQuantizer

# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
quantizer_state_dict = torch.load(
_get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
)
1 change: 1 addition & 0 deletions modelopt/torch/opt/searcher.py
@@ -249,6 +249,7 @@ def load_search_checkpoint(self) -> bool:

# iterate through state dict and load keys
print(f"Loading searcher state from {checkpoint}...")
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
state_dict = torch.load(checkpoint, weights_only=False)
assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"
for key, state in state_dict.items():
26 changes: 23 additions & 3 deletions modelopt/torch/quantization/plugins/attention.py
@@ -207,22 +207,24 @@ def patch_binop(node, quantizer_names, transpose=False):
head = ast.fix_missing_locations(head)
org_class = model_module.__dict__[org_class_name]

quant_class = _create_quantized_class_from_ast(head, org_class, new_class_name, model_module)
quant_class = _create_quantized_class_from_ast(head, org_class, new_class_name)
register(original_cls=org_class, quantized_cls=quant_class)
print(f"Successfully registered {org_class_name} for quantization")
return True


def _create_quantized_class_from_ast(
head, org_class, new_class_name, model_module, temp_file_name=None
head: ast.Module,
org_class: type,
new_class_name: str,
temp_file_name: str | None = None,
):
"""Create a quantized class from an AST representation.

Args:
head: The AST head containing the modified class definition
org_class: The original class to be quantized
new_class_name: Name for the new quantized class
model_module: The module containing the original class
temp_file_name: Optional file name to save the generated code

Returns:
@@ -232,6 +234,19 @@ def _create_quantized_class_from_ast(

# Save the generated code to a temporary file if requested
module_code_str = ast.unparse(head)

# Security: Validate generated code doesn't contain suspicious patterns
suspicious_patterns = ["__import__", "eval", "exec", "compile", "open(", "os.system"]
for pattern in suspicious_patterns:
if pattern in module_code_str:
# Allow compile for specific trusted ModelOpt internal use
if pattern == "compile" and "torch.compile" in module_code_str:
continue
raise ValueError(
f"Generated code contains suspicious pattern '{pattern}'. "
f"This may indicate a security issue in AST transformation."
)

if temp_file_name is None:
with tempfile.NamedTemporaryFile(
prefix="modelopt_", suffix=".py", delete=False
@@ -253,6 +268,11 @@ def _create_quantized_class_from_ast(
# ) # bandit throws error here
# quant_class = model_module.__dict__[new_class_name]

# Security NOTE: compile() is used here on internally-generated AST,
# not on untrusted user input. The AST is created by ModelOpt's quantization
# logic and has been validated above. This is safer than exec() but still
# requires the AST transformation logic to be secure.

# Extract the bytecode and create a new class on the fly
# This is more tricky but doesn't require runtime execution
module_code = compile(head, filename=f"{temp_file_name}", mode="exec")
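As a standalone illustration of the compile-from-AST pattern the new comment describes, a toy version might look like the following. This is a simplified sketch, not the ModelOpt implementation: it just executes the compiled module code object in a scratch namespace, whereas the real AST rewriting and class-extraction logic in `attention.py` is more involved.

```python
import ast

# Internally generated source: in ModelOpt this comes from rewriting the original
# attention class's AST; here a hard-coded toy class stands in for it.
src = """
class QuantToy:
    def forward(self, x):
        return x * 2
"""
tree = ast.parse(src)
tree = ast.fix_missing_locations(tree)

# Compile the AST directly (no string-level eval of user input), then run the
# resulting module code object in a scratch namespace to materialize the class.
namespace: dict = {}
code = compile(tree, filename="<modelopt_generated>", mode="exec")
exec(code, namespace)  # trusted, internally generated AST only

QuantToy = namespace["QuantToy"]
print(QuantToy().forward(3))  # -> 6
```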
@@ -188,6 +188,7 @@ def _save_modelopt_state_with_weights(self):
print_rank_0(f"Saved modelopt state to {self._modelopt_state_path}")

def _restore_modelopt_state_with_weights(self):
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
modelopt_state = torch.load(self._modelopt_state_path, weights_only=False)
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
restore_from_modelopt_state(self.model, modelopt_state)
1 change: 1 addition & 0 deletions modelopt/torch/utils/distributed.py
@@ -87,6 +87,7 @@ def _deserialize(tensor: torch.Tensor, size: int | None = None) -> Any:
buffer = tensor.numpy().tobytes()
if size is not None:
buffer = buffer[:size]
# Security NOTE: weights_only=False is used here on internally-generated buffer, not on untrusted user input
obj = torch.load(io.BytesIO(buffer), weights_only=False)
return obj

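The `_deserialize` helper above is the receiving half of a byte-tensor round trip used to move Python objects between processes. A self-contained sketch of both halves is shown below; the helper names are illustrative, not the module's public API.

```python
import io

import torch

def to_byte_tensor(obj) -> torch.Tensor:
    # Sending half (illustrative): pickle the object via torch.save into a uint8 tensor.
    buf = io.BytesIO()
    torch.save(obj, buf)
    return torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8).clone()

def from_byte_tensor(tensor: torch.Tensor, size: int | None = None):
    # Receiving half, mirroring _deserialize: rebuild the byte string and load it.
    buffer = tensor.numpy().tobytes()
    if size is not None:
        buffer = buffer[:size]
    # weights_only=False is required because the payload is an arbitrary Python
    # object produced by this process group, not untrusted user input.
    return torch.load(io.BytesIO(buffer), weights_only=False)

payload = {"config": {"bits": 8}, "note": "hello"}
assert from_byte_tensor(to_byte_tensor(payload)) == payload
```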