
Commit 0446652

save_pretrained

1 parent 55ea7ce

File tree: 1 file changed (+67, -51 lines)

src/diffusers/models/modeling_utils.py

Lines changed: 67 additions & 51 deletions
@@ -648,6 +648,7 @@ def save_pretrained(
         variant: Optional[str] = None,
         max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
@@ -700,7 +701,12 @@ def save_pretrained(
                 " the logger on the traceback to understand the reason why the quantized model is not serializable."
             )

-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = WEIGHTS_NAME
+        if use_flashpack:
+            weights_name = FLASHPACK_WEIGHTS_NAME
+        elif safe_serialization:
+            weights_name = SAFETENSORS_WEIGHTS_NAME
+
         weights_name = _add_variant(weights_name, variant)
         weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
             ".safetensors", "{suffix}.safetensors"
@@ -727,58 +733,68 @@ def save_pretrained(
         # Save the model
         state_dict = model_to_save.state_dict()

-        # Save the model
-        state_dict_split = split_torch_state_dict_into_shards(
-            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
-        )
-
-        # Clean the folder from a previous save
-        if is_main_process:
-            for filename in os.listdir(save_directory):
-                if filename in state_dict_split.filename_to_tensors.keys():
-                    continue
-                full_filename = os.path.join(save_directory, filename)
-                if not os.path.isfile(full_filename):
-                    continue
-                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
-                weights_without_ext = weights_without_ext.replace("{suffix}", "")
-                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
-                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
-                if (
-                    filename.startswith(weights_without_ext)
-                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
-                ):
-                    os.remove(full_filename)
-
-        for filename, tensors in state_dict_split.filename_to_tensors.items():
-            shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
-            filepath = os.path.join(save_directory, filename)
-            if safe_serialization:
-                # At some point we will need to deal better with save_function (used for TPU and other distributed
-                # joyfulness), but for now this is enough.
-                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
-            else:
-                torch.save(shard, filepath)
+        if use_flashpack:
+            if is_flashpack_available():
+                import flashpack

-        if state_dict_split.is_sharded:
-            index = {
-                "metadata": state_dict_split.metadata,
-                "weight_map": state_dict_split.tensor_to_filename,
-            }
-            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
-            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
-            # Save the index as well
-            with open(save_index_file, "w", encoding="utf-8") as f:
-                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-                f.write(content)
-            logger.info(
-                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
-                f"index located at {save_index_file}."
-            )
+                flashpack.serialization.pack_to_file(
+                    state_dict_or_model=state_dict,
+                    destination_path=save_directory,
+                    target_dtype=self.dtype(),
+                )
         else:
-            path_to_weights = os.path.join(save_directory, weights_name)
-            logger.info(f"Model weights saved in {path_to_weights}")
+            # Save the model
+            state_dict_split = split_torch_state_dict_into_shards(
+                state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+            )
+
+            # Clean the folder from a previous save
+            if is_main_process:
+                for filename in os.listdir(save_directory):
+                    if filename in state_dict_split.filename_to_tensors.keys():
+                        continue
+                    full_filename = os.path.join(save_directory, filename)
+                    if not os.path.isfile(full_filename):
+                        continue
+                    weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                    weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                    filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                    # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                    if (
+                        filename.startswith(weights_without_ext)
+                        and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                    ):
+                        os.remove(full_filename)
+
+            for filename, tensors in state_dict_split.filename_to_tensors.items():
+                shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
+                filepath = os.path.join(save_directory, filename)
+                if safe_serialization:
+                    # At some point we will need to deal better with save_function (used for TPU and other distributed
+                    # joyfulness), but for now this is enough.
+                    safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+                else:
+                    torch.save(shard, filepath)
+
+            if state_dict_split.is_sharded:
+                index = {
+                    "metadata": state_dict_split.metadata,
+                    "weight_map": state_dict_split.tensor_to_filename,
+                }
+                save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+                save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+                # Save the index as well
+                with open(save_index_file, "w", encoding="utf-8") as f:
+                    content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                    f.write(content)
+                logger.info(
+                    f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                    f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
+                    f"index located at {save_index_file}."
+                )
+            else:
+                path_to_weights = os.path.join(save_directory, weights_name)
+                logger.info(f"Model weights saved in {path_to_weights}")

         if push_to_hub:
             # Create a new empty model card and eventually tag it
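
Two notes on the new branch as written: flashpack.serialization.pack_to_file receives the full, unsharded state dict and the destination directory, so max_shard_size and the stale-shard cleanup do not apply to FlashPack saves; and if use_flashpack is true while is_flashpack_available() returns false, no weights appear to be written (no fallback or warning is visible in this hunk). The sharded non-FlashPack path is unchanged; the index written next to the shards keeps its existing shape. A sketch with illustrative tensor names, shard count, and size:

    # Structure of the index JSON written when state_dict_split.is_sharded is true
    # (tensor names, shard count, and total_size below are illustrative):
    index = {
        "metadata": {"total_size": 3438167552},
        "weight_map": {
            "conv_in.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
            "conv_in.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
            "conv_out.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
        },
    }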
