Saves a given torch model to disk, handling sharding and shared tensors issues.
@@ -64,6 +66,12 @@ def save_torch_model(
 
 </Tip>
 
+<Tip warning={true}>
+
+If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
+
+</Tip>
+
 Args:
     model (`torch.nn.Module`):
         The model to save on disk.
@@ -88,6 +96,13 @@ def save_torch_model(
         Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
         Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
         in a future version.
+    is_main_process (`bool`, *optional*):
+        Whether the process calling this is the main process or not. Useful when in distributed training like
+        TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+        the main process to avoid race conditions. Defaults to True.
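The tip on shared tensors added in this change says that passing `model._tied_weights_keys` as `shared_tensors_to_discard` makes sure the correct duplicates are dropped. The sketch below only illustrates that idea with plain Python objects; `drop_shared_duplicates` is a hypothetical helper, object identity stands in for shared tensor storage, and none of this is taken from the library's actual implementation.

```python
from typing import Any, Dict, Iterable


def drop_shared_duplicates(
    state_dict: Dict[str, Any],
    shared_tensors_to_discard: Iterable[str] = (),
) -> Dict[str, Any]:
    """Keep one name per shared object, dropping the listed names first.

    Hypothetical illustration: `id()` stands in for "same tensor storage".
    """
    discard = set(shared_tensors_to_discard)

    # Group parameter names that point at the same underlying object.
    groups: Dict[int, list] = {}
    for name, tensor in state_dict.items():
        groups.setdefault(id(tensor), []).append(name)

    cleaned: Dict[str, Any] = {}
    for names in groups.values():
        # Prefer a name the caller did not ask to discard; fall back to
        # the whole group so the tensor is never lost entirely.
        preferred = [n for n in names if n not in discard] or names
        keep = sorted(preferred)[0]
        cleaned[keep] = state_dict[keep]
    return cleaned


embedding = [0.1, 0.2, 0.3]  # stand-in for a tensor
state_dict = {
    "embed.weight": embedding,
    "lm_head.weight": embedding,  # same object, i.e. a tied weight
    "norm.weight": [1.0],
}

# Analogous to passing `model._tied_weights_keys` in the real call.
cleaned = drop_shared_duplicates(state_dict, ["lm_head.weight"])
print(sorted(cleaned))
```

Without the discard list, which of two tied names survives depends on an arbitrary rule (alphabetical order in this sketch). Naming the keys to discard makes the choice explicit, which is the point the tip makes for `transformers.PreTrainedModel`.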
Save a model state dictionary to the disk, handling sharding and shared tensors issues.
@@ -147,6 +166,12 @@ def save_torch_state_dict(
 
 </Tip>
 
+<Tip warning={true}>
+
+If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
+
+</Tip>
+
 Args:
     state_dict (`Dict[str, torch.Tensor]`):
         The state dictionary to save.
@@ -171,6 +196,13 @@ def save_torch_state_dict(
         Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
         Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
         in a future version.
+    is_main_process (`bool`, *optional*):
+        Whether the process calling this is the main process or not. Useful when in distributed training like
+        TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+        the main process to avoid race conditions. Defaults to True.
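The `is_main_process` docstring describes a calling pattern in which every process invokes the save function but only one is flagged as main. The diff shows only the docstring, so how the flag is used internally is not known from this change. The sketch below is a hedged, standard-library illustration of the pattern with a hypothetical `save_checkpoint` helper that simply skips writing on non-main processes, with ranks simulated in a loop.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


def save_checkpoint(state: dict, directory: str, is_main_process: bool = True) -> Optional[Path]:
    """Hypothetical helper: safe to call from every process, writes only on main."""
    if not is_main_process:
        # Non-main processes return early, so no two processes
        # race to write the same file.
        return None
    path = Path(directory) / "checkpoint.json"
    path.write_text(json.dumps(state))
    return path


with tempfile.TemporaryDirectory() as tmp:
    # Simulate four ranks all calling the same function;
    # only rank 0 passes is_main_process=True.
    results = [
        save_checkpoint({"step": 1}, tmp, is_main_process=(rank == 0))
        for rank in range(4)
    ]
    written = [p.name for p in results if p is not None]
    files_on_disk = sorted(p.name for p in Path(tmp).iterdir())

print(written, files_on_disk)
```

Keeping the call site identical on all ranks and moving the decision into a boolean argument avoids rank-specific branches in training code, which appears to be the motivation behind the new parameter.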