4 changes: 1 addition & 3 deletions ai_edge_torch/generative/layers/lora.py
@@ -268,9 +268,7 @@ def random(
       LoRA weights with random values.
     """
     return cls._from_tensor_generator(
-        tensor_generator=lambda shape, dtype: torch.randint(
-            low=0, high=128, size=shape, dtype=dtype
-        ),
+        tensor_generator=lambda shape, dtype: torch.rand(shape, dtype=dtype),
         rank=rank,
         config=config,
         dtype=dtype,
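The change swaps an integer generator for a uniform float generator. A minimal standalone sketch of the difference, not part of the repo's code (the shape is illustrative):

import torch

shape = (2, 4)

# Old behaviour: integers drawn from [0, 128), which become integer-valued
# floats as large as 127 once cast -- not a useful range for random LoRA weights.
ints = torch.randint(low=0, high=128, size=shape).to(torch.float32)

# New behaviour: values drawn uniformly from [0, 1).
floats = torch.rand(shape, dtype=torch.float32)

print(ints.min().item(), ints.max().item())      # e.g. 4.0 121.0
print(floats.min().item(), floats.max().item())  # always within [0, 1)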
18 changes: 12 additions & 6 deletions ai_edge_torch/generative/utilities/loader.py
@@ -48,11 +48,15 @@ def get_custom_loader(
     if checkpoint_format == "safetensors":
       return load_file
     if checkpoint_format == "pt":
-      return lambda path: torch.load(path, weights_only=True)
+      return lambda path: torch.load(
+          path, weights_only=True, map_location=torch.device("cpu")
+      )
     raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
 
   if os.path.splitext(checkpoint_path)[1] in [".bin", ".pt", ".ckpt"]:
-    return lambda path: torch.load(path, weights_only=True)
+    return lambda path: torch.load(
+        path, weights_only=True, map_location=torch.device("cpu")
+    )
   if checkpoint_path.endswith(".safetensors"):
     return load_file
   raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
@@ -126,7 +130,7 @@ def load_pytorch_statedict(full_path: str):
   patterns = []
   if os.path.isdir(full_path):
     patterns.append(os.path.join(full_path, "*.bin"))
-    patterns.append(os.path.join(full_path, "*pt"))
+    patterns.append(os.path.join(full_path, "*.pt"))
   else:
     patterns.append(full_path)
   for pattern in patterns:
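Adding the missing dot narrows the glob so only real .pt files are picked up. A quick illustration with fnmatch, which glob uses for matching (file names are made up):

from fnmatch import fnmatch

names = ["model.pt", "model.ckpt", "prompt"]

# Old pattern: anything ending in the two characters "pt" matches.
print([n for n in names if fnmatch(n, "*pt")])   # ['model.pt', 'model.ckpt', 'prompt']

# Fixed pattern: only a genuine ".pt" extension matches.
print([n for n in names if fnmatch(n, "*.pt")])  # ['model.pt']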
@@ -135,7 +139,9 @@
 
   tensors = {}
   for file in files:
-    this_file_tensors = torch.load(file)
+    this_file_tensors = torch.load(
+        file, map_location=torch.device("cpu"), weights_only=True
+    )
     for k in this_file_tensors:
       assert k not in tensors
     tensors.update(this_file_tensors)
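This hunk also opts the sharded-checkpoint loader into weights_only=True. A minimal sketch of what that flag buys (the file name is illustrative):

import torch

# weights_only=True restricts torch.load's unpickler to tensor data and a
# small allowlist of types, so loading an untrusted checkpoint cannot run
# arbitrary code through pickle.
state_dict = torch.load(
    "pytorch_model.bin",
    map_location=torch.device("cpu"),
    weights_only=True,
)

# Checkpoints that pickle arbitrary Python objects raise an UnpicklingError
# under this flag and must be loaded with weights_only=False from a trusted
# source instead.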
@@ -279,14 +285,14 @@ def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
       if glob.glob(os.path.join(self._file_name, "*.safetensors")):
         return load_safetensors
       if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
-          os.path.join(self._file_name, "*pt")
+          os.path.join(self._file_name, "*.pt")
       ):
         return load_pytorch_statedict
 
     if self._file_name.endswith(".safetensors"):
       return load_safetensors
 
-    if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
+    if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
       return load_pytorch_statedict
 
     raise ValueError("File format not supported.")