
Commit d67b99e

Merge branch 'main' into match-assertions-big-runner
2 parents: 86260b2 + 52c05bd

File tree: 5 files changed (+28, −5 lines)

src/diffusers/loaders/single_file.py

Lines changed: 8 additions & 0 deletions

@@ -60,6 +60,7 @@ def load_single_file_sub_model(
     local_files_only=False,
     torch_dtype=None,
     is_legacy_loading=False,
+    disable_mmap=False,
     **kwargs,
 ):
     if is_pipeline_module:
@@ -106,6 +107,7 @@ def load_single_file_sub_model(
             subfolder=name,
             torch_dtype=torch_dtype,
             local_files_only=local_files_only,
+            disable_mmap=disable_mmap,
             **kwargs,
         )
 
@@ -308,6 +310,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
                       hosted on the Hub.
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
                       component configs in Diffusers format.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                 class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -355,6 +360,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         is_legacy_loading = False
 
@@ -383,6 +389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             revision=revision,
+            disable_mmap=disable_mmap,
         )
 
         if config is None:
@@ -504,6 +511,7 @@ def load_module(name, value):
                     original_config=original_config,
                     local_files_only=local_files_only,
                     is_legacy_loading=is_legacy_loading,
+                    disable_mmap=disable_mmap,
                     **kwargs,
                 )
             except SingleFileComponentError as e:
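
For context, a minimal usage sketch of the new pipeline-level flag. The pipeline class and checkpoint path below are placeholders, not part of this diff:

import torch
from diffusers import StableDiffusionPipeline

# disable_mmap=True reads the safetensors checkpoint fully into memory instead of
# memory-mapping it, which can be faster on network mounts or spinning disks.
pipe = StableDiffusionPipeline.from_single_file(
    "path/to/checkpoint.safetensors",  # placeholder path
    torch_dtype=torch.float16,
    disable_mmap=True,
)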

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions

@@ -187,6 +187,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (for example the pipeline components of the
                 specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
@@ -234,6 +237,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict
@@ -246,6 +250,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 cache_dir=cache_dir,
                 local_files_only=local_files_only,
                 revision=revision,
+                disable_mmap=disable_mmap,
             )
         if quantization_config is not None:
             hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
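
The same flag is now accepted by the model-level single-file loader. A hedged sketch, using ControlNetModel as one single-file-loadable class and a placeholder local path:

from diffusers import ControlNetModel

# The flag is popped from kwargs and forwarded to load_single_file_checkpoint.
controlnet = ControlNetModel.from_single_file(
    "./checkpoints/controlnet.safetensors",  # placeholder path
    disable_mmap=True,
)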

src/diffusers/loaders/single_file_utils.py

Lines changed: 2 additions & 1 deletion

@@ -387,6 +387,7 @@ def load_single_file_checkpoint(
     cache_dir=None,
     local_files_only=None,
     revision=None,
+    disable_mmap=False,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
         pretrained_model_link_or_path = pretrained_model_link_or_path
@@ -404,7 +405,7 @@ def load_single_file_checkpoint(
             revision=revision,
         )
 
-    checkpoint = load_state_dict(pretrained_model_link_or_path)
+    checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
 
     # some checkpoints contain the model state dict under a "state_dict" key
     while "state_dict" in checkpoint:

src/diffusers/models/model_loading_utils.py

Lines changed: 7 additions & 2 deletions

@@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class
 
 
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def load_state_dict(
+    checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
+):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
@@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+            if disable_mmap:
+                return safetensors.torch.load(open(checkpoint_file, "rb").read())
+            else:
+                return safetensors.torch.load_file(checkpoint_file, device="cpu")
         elif file_extension == GGUF_FILE_EXTENSION:
             return load_gguf_checkpoint(checkpoint_file)
         else:
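
Restated as a standalone sketch (not the library function itself): with mmap disabled, the whole file is read into memory and deserialized with safetensors.torch.load; otherwise it is memory-mapped via safetensors.torch.load_file.

import safetensors.torch

def load_safetensors(path: str, disable_mmap: bool = False):
    # disable_mmap=True: read the raw bytes eagerly, then deserialize them.
    if disable_mmap:
        with open(path, "rb") as f:
            return safetensors.torch.load(f.read())
    # Default: memory-map the file and materialize tensors on CPU.
    return safetensors.torch.load_file(path, device="cpu")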

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 2 deletions

@@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
 
         <Tip>
 
@@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                 else:
                     param_device = torch.device(torch.cuda.current_device())
-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
                 model._convert_deprecated_attention_blocks(state_dict)
 
                 # move the params from meta device to cpu
@@ -979,7 +983,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:
                 model = cls.from_config(config, **unused_kwargs)
 
-            state_dict = load_state_dict(model_file, variant=variant)
+            state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
             model._convert_deprecated_attention_blocks(state_dict)
 
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
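
Finally, a usage sketch for the from_pretrained path; the repo id is only an example, and the flag just changes how safetensors weights are read from disk:

from diffusers import AutoencoderKL

# Example repo id; disable_mmap is forwarded to load_state_dict internally.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",
    disable_mmap=True,
)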
