From e8184f21867e31d1b87d89578e4b191a1875aa80 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 17 Sep 2024 00:30:42 +0200
Subject: [PATCH 1/2] update default max_shard_size

---
 src/diffusers/pipelines/pipeline_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index dffd49cb0ce7..18cf252974f5 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -189,7 +189,7 @@ def save_pretrained(
         save_directory: Union[str, os.PathLike],
         safe_serialization: bool = True,
         variant: Optional[str] = None,
-        max_shard_size: Union[int, str] = "10GB",
+        max_shard_size: Optional[Union[int, str]] = None,
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -205,7 +205,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
-            max_shard_size (`int` or `str`, defaults to `"10GB"`):
+            max_shard_size (`int` or `str`, defaults to `None`):
                 The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                 lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
                 If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain

From c77c6648d5f5d1f0611f033f962bdc8c85e48b4b Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 17 Sep 2024 03:45:33 +0200
Subject: [PATCH 2/2] add None check to fix tests

---
 src/diffusers/pipelines/pipeline_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 18cf252974f5..ccd1c9485d0e 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -293,7 +293,8 @@ def is_saveable_module(name, value):
             save_kwargs["safe_serialization"] = safe_serialization
             if save_method_accept_variant:
                 save_kwargs["variant"] = variant
-            if save_method_accept_max_shard_size:
+            if save_method_accept_max_shard_size and max_shard_size is not None:
+                # max_shard_size is expected to not be None in ModelMixin
                 save_kwargs["max_shard_size"] = max_shard_size
 
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
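
For reference, a minimal usage sketch of the changed parameter (Python, assuming the
diffusers library is installed; the model id and output directories below are
illustrative assumptions, not part of the patch):

    # Illustrative sketch only: the model id and directories are assumptions.
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")

    # With this patch, max_shard_size defaults to None, so the kwarg is only
    # forwarded to a component's save method when the caller sets it explicitly.
    pipe.save_pretrained("./sd21-local")

    # An explicit value is still forwarded and shards component checkpoints,
    # e.g. keeping each shard below roughly 2GB.
    pipe.save_pretrained("./sd21-sharded", max_shard_size="2GB")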