diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index aa6da17edfe7..dffd49cb0ce7 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -189,6 +189,7 @@ def save_pretrained( save_directory: Union[str, os.PathLike], safe_serialization: bool = True, variant: Optional[str] = None, + max_shard_size: Union[int, str] = "10GB", push_to_hub: bool = False, **kwargs, ): @@ -204,6 +205,13 @@ class implements both a save and loading method. The pipeline is easily reloaded Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. + max_shard_size (`int` or `str`, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain + period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`. + This is to establish a common default size for this argument across different libraries in the Hugging + Face ecosystem (`transformers`, and `accelerate`, for example). push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your @@ -278,12 +286,15 @@ def is_saveable_module(name, value): save_method_signature = inspect.signature(save_method) save_method_accept_safe = "safe_serialization" in save_method_signature.parameters save_method_accept_variant = "variant" in save_method_signature.parameters + save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters save_kwargs = {} if save_method_accept_safe: save_kwargs["safe_serialization"] = safe_serialization if save_method_accept_variant: save_kwargs["variant"] = variant + if save_method_accept_max_shard_size: + save_kwargs["max_shard_size"] = max_shard_size save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)