Skip to content

Commit 9a8860c

Browse files
authored
convert : use all parts in safetensors index (ggml-org#17286)
1 parent 9d3ef48 commit 9a8860c

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

convert_hf_to_gguf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,10 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
189189
return tensors
190190

191191
prefix = "model" if not self.is_mistral_format else "consolidated"
192-
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
192+
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
193193
is_safetensors: bool = len(part_names) > 0
194194
if not is_safetensors:
195-
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
195+
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
196196

197197
tensor_names_from_index: set[str] = set()
198198

@@ -209,6 +209,7 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
209209
if weight_map is None or not isinstance(weight_map, dict):
210210
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
211211
tensor_names_from_index.update(weight_map.keys())
212+
part_names |= set(weight_map.values())
212213
else:
213214
weight_map = {}
214215
else:

0 commit comments

Comments
 (0)