From 5432f60295d9cf1d37be48f04019a7fb00299db1 Mon Sep 17 00:00:00 2001
From: gaikwadrahul8
Date: Thu, 7 Aug 2025 22:28:42 +0000
Subject: [PATCH] fix(security): harden checkpoint loading and fix a LoRA
 initialization bug

- loader: use precise .pt detection (glob: *.pt, suffix: .pt)
- loader: enforce torch.load on CPU with weights_only for safety
- lora: use torch.rand for float dtype in random() to avoid a runtime error
---
 ai_edge_torch/generative/layers/lora.py      |  4 +---
 ai_edge_torch/generative/utilities/loader.py | 18 ++++++++++++------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/ai_edge_torch/generative/layers/lora.py b/ai_edge_torch/generative/layers/lora.py
index 9c4dedbe..69a5057c 100644
--- a/ai_edge_torch/generative/layers/lora.py
+++ b/ai_edge_torch/generative/layers/lora.py
@@ -268,9 +268,7 @@ def random(
       LoRA weights with random values.
     """
     return cls._from_tensor_generator(
-        tensor_generator=lambda shape, dtype: torch.randint(
-            low=0, high=128, size=shape, dtype=dtype
-        ),
+        tensor_generator=lambda shape, dtype: torch.rand(shape, dtype=dtype),
         rank=rank,
         config=config,
         dtype=dtype,
diff --git a/ai_edge_torch/generative/utilities/loader.py b/ai_edge_torch/generative/utilities/loader.py
index 63c177d8..00ce93a4 100644
--- a/ai_edge_torch/generative/utilities/loader.py
+++ b/ai_edge_torch/generative/utilities/loader.py
@@ -48,11 +48,15 @@ def get_custom_loader(
     if checkpoint_format == "safetensors":
       return load_file
     if checkpoint_format == "pt":
-      return lambda path: torch.load(path, weights_only=True)
+      return lambda path: torch.load(
+          path, weights_only=True, map_location=torch.device("cpu")
+      )
     raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
 
   if os.path.splitext(checkpoint_path)[1] in [".bin", ".pt", ".ckpt"]:
-    return lambda path: torch.load(path, weights_only=True)
+    return lambda path: torch.load(
+        path, weights_only=True, map_location=torch.device("cpu")
+    )
   if checkpoint_path.endswith(".safetensors"):
     return load_file
   raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
@@ -126,7 +130,7 @@ def load_pytorch_statedict(full_path: str):
   patterns = []
   if os.path.isdir(full_path):
     patterns.append(os.path.join(full_path, "*.bin"))
-    patterns.append(os.path.join(full_path, "*pt"))
+    patterns.append(os.path.join(full_path, "*.pt"))
   else:
     patterns.append(full_path)
   for pattern in patterns:
@@ -135,7 +139,9 @@
 
   tensors = {}
   for file in files:
-    this_file_tensors = torch.load(file)
+    this_file_tensors = torch.load(
+        file, map_location=torch.device("cpu"), weights_only=True
+    )
     for k in this_file_tensors:
       assert k not in tensors
     tensors.update(this_file_tensors)
@@ -279,14 +285,14 @@ def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
       if glob.glob(os.path.join(self._file_name, "*.safetensors")):
         return load_safetensors
       if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
-          os.path.join(self._file_name, "*pt")
+          os.path.join(self._file_name, "*.pt")
       ):
         return load_pytorch_statedict
 
     if self._file_name.endswith(".safetensors"):
       return load_safetensors
 
-    if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
+    if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
       return load_pytorch_statedict
 
     raise ValueError("File format not supported.")
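
Reviewer note, not part of the applied patch: a minimal sketch of how the
hardened torch.load path behaves after this change. The checkpoint name
"model.ckpt" is a hypothetical stand-in for illustration; get_custom_loader
is the function patched above.

    import torch

    from ai_edge_torch.generative.utilities.loader import get_custom_loader

    # Write a small, benign state dict to exercise the .ckpt suffix branch.
    torch.save({"w": torch.zeros(2, 2)}, "model.ckpt")

    # For .bin/.pt/.ckpt paths, get_custom_loader returns a torch.load
    # wrapper; after this patch it always passes weights_only=True (so a
    # tampered checkpoint carrying pickled callables is rejected at load
    # time) and maps tensors to CPU regardless of the device they were
    # saved from.
    load = get_custom_loader("model.ckpt")
    tensors = load("model.ckpt")
    assert tensors["w"].device.type == "cpu"

The same weights_only + map_location pair is applied in
load_pytorch_statedict, so checkpoints loaded from a directory of *.bin or
*.pt shards get identical treatment.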