Skip to content

Commit e8aa61b

Browse files
committed
handle shared tensors
1 parent ec53008 commit e8aa61b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,10 @@ def save_pretrained(
714714
if safe_serialization:
715715
# At some point we will need to deal better with save_function (used for TPU and other distributed
716716
# joyfulness), but for now this enough.
717-
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
717+
try:
718+
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
719+
except RuntimeError:
720+
safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
718721
else:
719722
torch.save(shard, filepath)
720723

0 commit comments

Comments
 (0)