Skip to content

Commit fec8b77

Browse files
committed
Changes from upstream: 9a8860c
1 parent b67e080 commit fec8b77

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

conversion/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,10 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
170170
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
171171
return tensors
172172
prefix = "model" if not self.is_mistral_format else "consolidated"
173-
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
173+
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
174174
is_safetensors: bool = len(part_names) > 0
175175
if not is_safetensors:
176-
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
176+
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
177177
tensor_names_from_index: set[str] = set()
178178
if not self.is_mistral_format:
179179
index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
@@ -187,6 +187,7 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
187187
if weight_map is None or not isinstance(weight_map, dict):
188188
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
189189
tensor_names_from_index.update(weight_map.keys())
190+
part_names |= set(weight_map.values())
190191
else:
191192
weight_map = {}
192193
else:

0 commit comments

Comments
 (0)