diff --git a/docs/source/guides/2_save_load.rst b/docs/source/guides/2_save_load.rst index d0c0b8cb8..e097e3f80 100644 --- a/docs/source/guides/2_save_load.rst +++ b/docs/source/guides/2_save_load.rst @@ -129,6 +129,7 @@ Here is the example workflow of restoring the ModelOpt-modified model architectu model = ... # Restore the model architecture using the saved `modelopt_state` + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input modelopt_state = torch.load("modelopt_state.pth", weights_only=False) model = mto.restore_from_modelopt_state(model, modelopt_state) diff --git a/examples/llm_qat/export.py b/examples/llm_qat/export.py index 77d75d47a..7954f8eac 100644 --- a/examples/llm_qat/export.py +++ b/examples/llm_qat/export.py @@ -51,6 +51,7 @@ def get_model( # Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this if hasattr(model, "peft_config"): + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False) restore_from_modelopt_state(model, modelopt_state) print_rank_0("Restored modelopt state") diff --git a/examples/llm_sparsity/README.md b/examples/llm_sparsity/README.md index e7b8b30e0..541ba1ac2 100644 --- a/examples/llm_sparsity/README.md +++ b/examples/llm_sparsity/README.md @@ -84,7 +84,7 @@ python data_prep.py --save_path data The following command demonstrates how to perform SAT on the Llama2-7B model on 8 GPUs. The model is finetuned on the [cnn_dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset for 3 epochs. -The input data is tokenized to a maximum length of 1024 tokens. The tokenized data is saved as a pickle file for faster data loading. The one-time process takes less than an hour to finish depending on the CPU. The resulting pickle file can be utilized for future training sessions. +The input data is tokenized to a maximum length of 1024 tokens. ```sh bash launch_finetune.sh --model meta-llama/Llama-2-7b-hf \ diff --git a/examples/llm_sparsity/finetune.py b/examples/llm_sparsity/finetune.py index 3cfc1073f..573cae765 100644 --- a/examples/llm_sparsity/finetune.py +++ b/examples/llm_sparsity/finetune.py @@ -32,7 +32,6 @@ import argparse import copy import os -import pickle from collections.abc import Sequence from dataclasses import dataclass, field @@ -232,27 +231,17 @@ def __init__( ): super().__init__() - pickle_name = f"dict_{split}_{tokenizer.model_max_length}.pickle" with training_args.main_process_first(): - if os.path.isfile(pickle_name): - with open(pickle_name, "rb") as f: - print_rank_0("Reuse pickled data") - data_dict = pickle.load(f) - else: - print_rank_0("Loading data...") - list_data_dict = utils.jload(data_path) - - print_rank_0("Formatting inputs...") - prompt_input = PROMPT_DICT["prompt_input"] - sources = [prompt_input.format_map(example) for example in list_data_dict] - targets = [ - f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict - ] - - print_rank_0("Tokenizing inputs... 
This may take some time...") - data_dict = preprocess(sources, targets, tokenizer) - with open(pickle_name, "wb") as f: - pickle.dump(data_dict, f, pickle.HIGHEST_PROTOCOL) + print_rank_0("Loading data...") + list_data_dict = utils.jload(data_path) + + print_rank_0("Formatting inputs...") + prompt_input = PROMPT_DICT["prompt_input"] + sources = [prompt_input.format_map(example) for example in list_data_dict] + targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] + + print_rank_0("Tokenizing inputs... This may take some time...") + data_dict = preprocess(sources, targets, tokenizer) self.input_ids = data_dict["input_ids"] self.labels = data_dict["labels"] diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py index 55cca6ee5..26ed9e394 100644 --- a/modelopt/onnx/quantization/__main__.py +++ b/modelopt/onnx/quantization/__main__.py @@ -52,6 +52,11 @@ def get_parser() -> argparse.ArgumentParser: type=str, help="Calibration data in npz/npy format. If None, random data for calibration will be used.", ) + group.add_argument( + "--trust_calibration_data", + action="store_true", + help="If True, trust the calibration data and allow pickle deserialization.", + ) group.add_argument( "--calibration_cache_path", type=str, @@ -263,10 +268,23 @@ def main(): args = get_parser().parse_args() calibration_data = None if args.calibration_data_path: - calibration_data = np.load(args.calibration_data_path, allow_pickle=True) - if args.calibration_data_path.endswith(".npz"): - # Convert the NpzFile object to a Python dictionary - calibration_data = {key: calibration_data[key] for key in calibration_data.files} + # Security: Disable pickle deserialization for untrusted sources to prevent RCE attacks + try: + calibration_data = np.load( + args.calibration_data_path, allow_pickle=args.trust_calibration_data + ) + if args.calibration_data_path.endswith(".npz"): + # Convert the NpzFile object to a Python dictionary + calibration_data = {key: calibration_data[key] for key in calibration_data.files} + except ValueError as e: + if "allow_pickle" in str(e) and not args.trust_calibration_data: + raise ValueError( + "Calibration data file contains pickled objects which pose a security risk. " + "For trusted sources, you may enable pickle deserialization by setting the " + "--trust_calibration_data flag." + ) from e + else: + raise quantize( args.onnx_path, diff --git a/modelopt/torch/export/distribute.py b/modelopt/torch/export/distribute.py index f9d902fd2..4fe7be43e 100644 --- a/modelopt/torch/export/distribute.py +++ b/modelopt/torch/export/distribute.py @@ -91,6 +91,7 @@ def read_configs_and_weights_from_rank( raise ValueError("NFSWorkspace is not initialized!") state_path = self._get_state_path(target_rank) if state_path.exists(): + # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input state = torch.load(state_path, map_location="cpu", weights_only=False) return state["config"], state["weight"] else: diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 1de6143bd..874c51b59 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -526,6 +526,7 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any] model = ... 
# Create the model-like object # Restore the previously saved modelopt state followed by model weights + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input mto.restore_from_modelopt_state( model, torch.load("modelopt_state.pt", weights_only=False) ) # Restore modelopt state diff --git a/modelopt/torch/opt/plugins/huggingface.py b/modelopt/torch/opt/plugins/huggingface.py index 672d0f99a..99bab7725 100644 --- a/modelopt/torch/opt/plugins/huggingface.py +++ b/modelopt/torch/opt/plugins/huggingface.py @@ -79,6 +79,7 @@ def new_init_fn(self, *args, **kwargs): modelopt_state_path = _get_modelopt_state_path(model_path) _original__init__(self, *args, **kwargs) if os.path.isfile(modelopt_state_path): + # Security NOTE: weights_only=False is used on ModelOpt-generated state_dict, not on untrusted user input modelopt_state = torch.load(modelopt_state_path, map_location="cpu", weights_only=False) with extra_context() if extra_context else nullcontext(): restore_from_modelopt_state(self, modelopt_state) diff --git a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py index f62183630..2467a946c 100644 --- a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py +++ b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py @@ -242,6 +242,7 @@ def restore_sharded_modelopt_state( return # Loading the common modelopt_state (replicated on all ranks) + # Security NOTE: weights_only=False is used here on NVIDIA-generated file, not on untrusted user input common_modelopt_state = torch.load( modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, weights_only=False ) diff --git a/modelopt/torch/opt/plugins/megatron.py b/modelopt/torch/opt/plugins/megatron.py index 412f66f17..3a29a33f8 100644 --- a/modelopt/torch/opt/plugins/megatron.py +++ b/modelopt/torch/opt/plugins/megatron.py @@ -76,7 +76,7 @@ def _modelopt_set_extra_state(self, state: Any): # Default format: byte tensor with pickled data # # TODO: possible deserialization improvement - # https://github.com/NVIDIA/TensorRT-LLM/commits/main/tensorrt_llm/serialization.py + # https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serialization.py extra_state = pickle.loads(state.detach().cpu().numpy().tobytes()) # nosec else: raise RuntimeError("Unsupported extra_state format.") diff --git a/modelopt/torch/opt/plugins/peft.py b/modelopt/torch/opt/plugins/peft.py index 5e5ed0f93..c3fd268a5 100644 --- a/modelopt/torch/opt/plugins/peft.py +++ b/modelopt/torch/opt/plugins/peft.py @@ -72,6 +72,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs): assert adapter_name in self.peft_config, ( f"ModelOpt modified model should have adapter_name={adapter_name} in peft_config" ) + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input restore_from_modelopt_state( self, torch.load(modelopt_state_path, map_location="cpu", weights_only=False) ) @@ -85,6 +86,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs): if os.path.isfile(_get_quantizer_state_save_path(model_id)): from modelopt.torch.quantization.nn import TensorQuantizer + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input quantizer_state_dict = torch.load( _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False ) diff --git a/modelopt/torch/opt/searcher.py b/modelopt/torch/opt/searcher.py index 5eb2e134e..3052289cb 
100644 --- a/modelopt/torch/opt/searcher.py +++ b/modelopt/torch/opt/searcher.py @@ -249,6 +249,7 @@ def load_search_checkpoint(self) -> bool: # iterate through state dict and load keys print(f"Loading searcher state from {checkpoint}...") + # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input state_dict = torch.load(checkpoint, weights_only=False) assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!" for key, state in state_dict.items(): diff --git a/modelopt/torch/quantization/plugins/attention.py b/modelopt/torch/quantization/plugins/attention.py index 4761e9c35..e7069ee5c 100644 --- a/modelopt/torch/quantization/plugins/attention.py +++ b/modelopt/torch/quantization/plugins/attention.py @@ -207,14 +207,17 @@ def patch_binop(node, quantizer_names, transpose=False): head = ast.fix_missing_locations(head) org_class = model_module.__dict__[org_class_name] - quant_class = _create_quantized_class_from_ast(head, org_class, new_class_name, model_module) + quant_class = _create_quantized_class_from_ast(head, org_class, new_class_name) register(original_cls=org_class, quantized_cls=quant_class) print(f"Successfully registered {org_class_name} for quantization") return True def _create_quantized_class_from_ast( - head, org_class, new_class_name, model_module, temp_file_name=None + head: ast.Module, + org_class: type, + new_class_name: str, + temp_file_name: str | None = None, ): """Create a quantized class from an AST representation. @@ -222,7 +225,6 @@ def _create_quantized_class_from_ast( head: The AST head containing the modified class definition org_class: The original class to be quantized new_class_name: Name for the new quantized class - model_module: The module containing the original class temp_file_name: Optional file name to save the generated code Returns: @@ -232,6 +234,19 @@ def _create_quantized_class_from_ast( # Save the generated code to a temporary file if requested module_code_str = ast.unparse(head) + + # Security: Validate generated code doesn't contain suspicious patterns + suspicious_patterns = ["__import__", "eval", "exec", "compile", "open(", "os.system"] + for pattern in suspicious_patterns: + if pattern in module_code_str: + # Allow compile for specific trusted ModelOpt internal use + if pattern == "compile" and "torch.compile" in module_code_str: + continue + raise ValueError( + f"Generated code contains suspicious pattern '{pattern}'. " + f"This may indicate a security issue in AST transformation." + ) + if temp_file_name is None: with tempfile.NamedTemporaryFile( prefix="modelopt_", suffix=".py", delete=False @@ -253,6 +268,11 @@ def _create_quantized_class_from_ast( # ) # bandit throws error here # quant_class = model_module.__dict__[new_class_name] + # Security NOTE: compile() is used here on internally-generated AST, + # not on untrusted user input. The AST is created by ModelOpt's quantization + # logic and has been validated above. This is safer than exec() but still + # requires the AST transformation logic to be secure. 
+ # Extract the bytecode and create a new class on the fly # This is more tricky but doesn't require runtime execution module_code = compile(head, filename=f"{temp_file_name}", mode="exec") diff --git a/modelopt/torch/quantization/plugins/transformers_trainer.py b/modelopt/torch/quantization/plugins/transformers_trainer.py index ad562dae3..a29dc850b 100644 --- a/modelopt/torch/quantization/plugins/transformers_trainer.py +++ b/modelopt/torch/quantization/plugins/transformers_trainer.py @@ -188,6 +188,7 @@ def _save_modelopt_state_with_weights(self): print_rank_0(f"Saved modelopt state to {self._modelopt_state_path}") def _restore_modelopt_state_with_weights(self): + # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input modelopt_state = torch.load(self._modelopt_state_path, weights_only=False) modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) restore_from_modelopt_state(self.model, modelopt_state) diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 033b4aadb..171c85b46 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -87,6 +87,7 @@ def _deserialize(tensor: torch.Tensor, size: int | None = None) -> Any: buffer = tensor.numpy().tobytes() if size is not None: buffer = buffer[:size] + # Security NOTE: weights_only=False is used here on internally-generated buffer, not on untrusted user input obj = torch.load(io.BytesIO(buffer), weights_only=False) return obj
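
For context on the repeated `weights_only=False` security notes above, here is a minimal sketch of the trusted round trip they describe, closely following the `2_save_load.rst` example in this patch; `get_model()` and `calib_loop()` are hypothetical placeholders, and the file names are illustrative only:

```python
import torch
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq

# Quantize locally and save the ModelOpt state. Because this file is generated
# by ModelOpt in our own environment, it is a trusted artifact.
model = get_model()  # hypothetical model factory
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_loop)  # calib_loop: hypothetical calibration forward loop
torch.save(mto.modelopt_state(model), "modelopt_state.pth")
torch.save(model.state_dict(), "model_weights.pth")

# Restore later: weights_only=False is acceptable only because the checkpoint
# above is ModelOpt-generated, never for files from untrusted sources.
model = get_model()
model = mto.restore_from_modelopt_state(
    model, torch.load("modelopt_state.pth", weights_only=False)
)
model.load_state_dict(torch.load("model_weights.pth", weights_only=True))
```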
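
The `--trust_calibration_data` change in `modelopt/onnx/quantization/__main__.py` can be illustrated as follows; this is a sketch, and the `--onnx_path` and `--calibration_data_path` flag names are inferred from the argparse attributes used in the diff:

```python
import numpy as np

# Plain numeric arrays do not need pickle, so they load fine on the new
# default path (allow_pickle=False).
np.savez("calib.npz", input_ids=np.random.randint(0, 32000, size=(8, 128)))
data = np.load("calib.npz", allow_pickle=False)
calibration_data = {key: data[key] for key in data.files}

# A file containing object (pickled) arrays now raises ValueError unless the
# user explicitly opts in on the command line, e.g.:
#   python -m modelopt.onnx.quantization \
#       --onnx_path model.onnx \
#       --calibration_data_path calib.npz \
#       --trust_calibration_data
```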
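
A self-contained sketch of the substring screen added to `_create_quantized_class_from_ast`, showing that it is a lightweight guard on ModelOpt-generated source rather than a sandbox; the class body below is a made-up example:

```python
import ast

SUSPICIOUS = ["__import__", "eval", "exec", "compile", "open(", "os.system"]

def screen_generated_code(code_str: str) -> None:
    """Raise if the unparsed module contains an obviously dangerous construct."""
    for pattern in SUSPICIOUS:
        if pattern in code_str:
            if pattern == "compile" and "torch.compile" in code_str:
                continue  # trusted internal use of torch.compile
            raise ValueError(f"Generated code contains suspicious pattern {pattern!r}")

head = ast.parse("class QuantAttention:\n    def forward(self, x):\n        return x\n")
screen_generated_code(ast.unparse(head))  # passes: nothing flagged
module_code = compile(head, filename="<modelopt_generated>", mode="exec")  # bytecode only, not executed
```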