Skip to content

Commit f9beb79

Browse files
committed
add argument to pass shared tensors keys to discard (#2696)
1 parent d7bead5 commit f9beb79

File tree

2 files changed

+99
-11
lines changed

2 files changed

+99
-11
lines changed

src/huggingface_hub/serialization/_torch.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def save_torch_model(
4141
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
4242
metadata: Optional[Dict[str, str]] = None,
4343
safe_serialization: bool = True,
44+
is_main_process: bool = True,
45+
shared_tensors_to_discard: Optional[List[str]] = None,
4446
):
4547
"""
4648
Saves a given torch model to disk, handling sharding and shared tensors issues.
@@ -64,6 +66,12 @@ def save_torch_model(
6466
6567
</Tip>
6668
69+
<Tip warning={true}>
70+
71+
If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` so that shared tensors are saved properly. This ensures the correct duplicate tensors are discarded during saving.
72+
73+
</Tip>
74+
6775
Args:
6876
model (`torch.nn.Module`):
6977
The model to save on disk.
@@ -88,6 +96,13 @@ def save_torch_model(
8896
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
8997
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
9098
in a future version.
99+
is_main_process (`bool`, *optional*):
100+
Whether the process calling this is the main process or not. Useful in distributed training (e.g. on
100+
TPUs) when this function needs to be called from all processes. In this case, set `is_main_process=True` only on
101+
the main process to avoid race conditions. Defaults to `True`.
103+
shared_tensors_to_discard (`List[str]`, *optional*):
104+
List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
105+
detected, the first name alphabetically is kept and the other names are dropped.
91106
92107
Example:
93108
@@ -112,6 +127,8 @@ def save_torch_model(
112127
metadata=metadata,
113128
safe_serialization=safe_serialization,
114129
save_directory=save_directory,
130+
is_main_process=is_main_process,
131+
shared_tensors_to_discard=shared_tensors_to_discard,
115132
)
116133

117134

@@ -124,6 +141,8 @@ def save_torch_state_dict(
124141
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
125142
metadata: Optional[Dict[str, str]] = None,
126143
safe_serialization: bool = True,
144+
is_main_process: bool = True,
145+
shared_tensors_to_discard: Optional[List[str]] = None,
127146
) -> None:
128147
"""
129148
Save a model state dictionary to the disk, handling sharding and shared tensors issues.
@@ -147,6 +166,12 @@ def save_torch_state_dict(
147166
148167
</Tip>
149168
169+
<Tip warning={true}>
170+
171+
If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` so that shared tensors are saved properly. This ensures the correct duplicate tensors are discarded during saving.
172+
173+
</Tip>
174+
150175
Args:
151176
state_dict (`Dict[str, torch.Tensor]`):
152177
The state dictionary to save.
@@ -171,6 +196,13 @@ def save_torch_state_dict(
171196
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
172197
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
173198
in a future version.
199+
is_main_process (`bool`, *optional*):
200+
Whether the process calling this is the main process or not. Useful in distributed training (e.g. on
200+
TPUs) when this function needs to be called from all processes. In this case, set `is_main_process=True` only on
201+
the main process to avoid race conditions. Defaults to `True`.
203+
shared_tensors_to_discard (`List[str]`, *optional*):
204+
List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
205+
detected, the first name alphabetically is kept and the other names are dropped.
174206
175207
Example:
176208
@@ -192,7 +224,8 @@ def save_torch_state_dict(
192224
else constants.PYTORCH_WEIGHTS_FILE_PATTERN
193225
)
194226

195-
# Imports correct library
227+
if metadata is None:
228+
metadata = {}
196229
if safe_serialization:
197230
try:
198231
from safetensors.torch import save_file as save_file_fn
@@ -201,7 +234,13 @@ def save_torch_state_dict(
201234
"Please install `safetensors` to use safe serialization. "
202235
"You can install it with `pip install safetensors`."
203236
) from e
204-
237+
# Clean state dict for safetensors
238+
state_dict = _clean_state_dict_for_safetensors(
239+
state_dict,
240+
metadata,
241+
force_contiguous=force_contiguous,
242+
shared_tensors_to_discard=shared_tensors_to_discard,
243+
)
205244
else:
206245
from torch import save as save_file_fn # type: ignore[assignment]
207246

@@ -210,13 +249,6 @@ def save_torch_state_dict(
210249
"pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
211250
"using safe serialization by installing `safetensors` with `pip install safetensors`."
212251
)
213-
214-
# Clean state dict for safetensors
215-
if metadata is None:
216-
metadata = {}
217-
if safe_serialization:
218-
state_dict = _clean_state_dict_for_safetensors(state_dict, metadata, force_contiguous=force_contiguous)
219-
220252
# Split dict
221253
state_dict_split = split_torch_state_dict_into_shards(
222254
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
@@ -459,15 +491,18 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
459491

460492

461493
def _clean_state_dict_for_safetensors(
462-
state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True
494+
state_dict: Dict[str, "torch.Tensor"],
495+
metadata: Dict[str, str],
496+
force_contiguous: bool = True,
497+
shared_tensors_to_discard: Optional[List[str]] = None,
463498
):
464499
"""Remove shared tensors from state_dict and update metadata accordingly (for reloading).
465500
466501
Warning: `state_dict` and `metadata` are mutated in-place!
467502
468503
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
469504
"""
470-
to_removes = _remove_duplicate_names(state_dict)
505+
to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
471506
for kept_name, to_remove_group in to_removes.items():
472507
for to_remove in to_remove_group:
473508
if metadata is None:

tests/test_serialization.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
264264
max_shard_size="3GB",
265265
metadata={"foo": "bar"},
266266
safe_serialization=True,
267+
is_main_process=True,
268+
shared_tensors_to_discard=None,
267269
)
268270
safe_state_dict_mock.assert_called_once_with(
269271
state_dict=model_mock.state_dict.return_value,
@@ -273,6 +275,8 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
273275
max_shard_size="3GB",
274276
metadata={"foo": "bar"},
275277
safe_serialization=True,
278+
is_main_process=True,
279+
shared_tensors_to_discard=None,
276280
)
277281

278282

@@ -414,6 +418,55 @@ def test_save_torch_state_dict_shared_layers_sharded(
414418
assert "shared_2" not in state_dict
415419

416420

421+
def test_save_torch_state_dict_discard_selected_sharded(
422+
tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"]
423+
) -> None:
424+
from safetensors.torch import load_file
425+
426+
save_torch_state_dict(
427+
torch_state_dict_shared_layers,
428+
tmp_path,
429+
max_shard_size=2,
430+
safe_serialization=True,
431+
shared_tensors_to_discard=["shared_1"],
432+
)
433+
index_file = tmp_path / "model.safetensors.index.json"
434+
index = json.loads(index_file.read_text())
435+
436+
assert index["metadata"]["shared_1"] == "shared_2"
437+
438+
for filename in index["weight_map"].values():
439+
state_dict = load_file(tmp_path / filename)
440+
assert "shared_1" not in state_dict
441+
442+
443+
def test_save_torch_state_dict_discard_selected_not_sharded(
444+
tmp_path: Path, torch_state_dict_shared_layers: Dict[str, "torch.Tensor"]
445+
) -> None:
446+
from safetensors.torch import load_file
447+
448+
save_torch_state_dict(
449+
torch_state_dict_shared_layers,
450+
tmp_path,
451+
safe_serialization=True,
452+
shared_tensors_to_discard=["shared_1"],
453+
)
454+
safetensors_file = tmp_path / "model.safetensors"
455+
assert safetensors_file.is_file()
456+
457+
# Check shared layer not duplicated in file
458+
state_dict = load_file(safetensors_file)
459+
assert "shared_1" not in state_dict
460+
assert "shared_2" in state_dict
461+
462+
# Check shared layer info in metadata
463+
file_bytes = safetensors_file.read_bytes()
464+
metadata_str = file_bytes[
465+
8 : struct.unpack("<Q", file_bytes[:8])[0] + 8
466+
].decode() # TODO: next time add helper for this
467+
assert json.loads(metadata_str)["__metadata__"]["shared_1"] == "shared_2"
468+
469+
417470
def test_split_torch_state_dict_into_shards(
418471
tmp_path: Path, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"]
419472
):

0 commit comments

Comments
 (0)