
Commit f80644d

Rename to disable_mmap and update other references.
1 parent 3cc50f0 commit f80644d

5 files changed: +24 −16 lines

src/diffusers/loaders/single_file.py

Lines changed: 6 additions & 6 deletions
@@ -60,7 +60,7 @@ def load_single_file_sub_model(
     local_files_only=False,
     torch_dtype=None,
     is_legacy_loading=False,
-    no_mmap=False,
+    disable_mmap=False,
     **kwargs,
 ):
     if is_pipeline_module:
@@ -107,7 +107,7 @@ def load_single_file_sub_model(
             subfolder=name,
             torch_dtype=torch_dtype,
             local_files_only=local_files_only,
-            no_mmap=no_mmap,
+            disable_mmap=disable_mmap,
             **kwargs,
         )
 
@@ -310,7 +310,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
                   hosted on the Hub.
                 - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
                   component configs in Diffusers format.
-            no_mmap ('bool', *optional*, defaults to 'False'):
+            disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
             kwargs (remaining dictionary of keyword arguments, *optional*):
@@ -360,7 +360,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
-        no_mmap = kwargs.pop("no_mmap", False)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         is_legacy_loading = False
 
@@ -389,7 +389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             revision=revision,
-            no_mmap=no_mmap,
+            disable_mmap=disable_mmap,
         )
 
         if config is None:
@@ -511,7 +511,7 @@ def load_module(name, value):
                         original_config=original_config,
                         local_files_only=local_files_only,
                         is_legacy_loading=is_legacy_loading,
-                        no_mmap=no_mmap,
+                        disable_mmap=disable_mmap,
                         **kwargs,
                     )
             except SingleFileComponentError as e:
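After this commit, pipeline-level single-file loading accepts the renamed flag as an ordinary keyword argument. A minimal usage sketch, assuming a diffusers build that includes this change; the checkpoint path is a placeholder:

    from diffusers import StableDiffusionPipeline

    # Placeholder path: a single-file checkpoint on a network mount,
    # where mmap's scattered reads tend to be slow.
    pipe = StableDiffusionPipeline.from_single_file(
        "/mnt/nfs/checkpoints/model.safetensors",
        disable_mmap=True,  # read the whole file into memory instead of mmap-ing it
    )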

src/diffusers/loaders/single_file_model.py

Lines changed: 6 additions & 4 deletions
@@ -177,7 +177,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
-            no_mmap ('bool', *optional*, defaults to 'False'):
+            disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
             kwargs (remaining dictionary of keyword arguments, *optional*):
@@ -226,7 +226,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
-        no_mmap = kwargs.pop("no_mmap", False)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict
@@ -239,7 +239,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 cache_dir=cache_dir,
                 local_files_only=local_files_only,
                 revision=revision,
-                no_mmap=no_mmap,
+                disable_mmap=disable_mmap,
             )
         if quantization_config is not None:
             hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
@@ -361,7 +361,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             )
 
         else:
-            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False, no_mmap=no_mmap)
+            _, unexpected_keys = model.load_state_dict(
+                diffusers_format_checkpoint, strict=False, disable_mmap=disable_mmap
+            )
 
         if model._keys_to_ignore_on_load_unexpected is not None:
             for pat in model._keys_to_ignore_on_load_unexpected:
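The same flag also works at the single-model level, since this file implements `from_single_file` for individual model classes. A hedged sketch; the model class and path are illustrative, not taken from the commit:

    from diffusers import AutoencoderKL

    # Illustrative: load a standalone VAE checkpoint without mmap.
    vae = AutoencoderKL.from_single_file(
        "/mnt/nfs/checkpoints/vae.safetensors",  # placeholder path
        disable_mmap=True,
    )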

src/diffusers/loaders/single_file_utils.py

Lines changed: 2 additions & 2 deletions
@@ -375,7 +375,7 @@ def load_single_file_checkpoint(
     cache_dir=None,
     local_files_only=None,
     revision=None,
-    no_mmap=False,
+    disable_mmap=False,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
         pretrained_model_link_or_path = pretrained_model_link_or_path
@@ -393,7 +393,7 @@ def load_single_file_checkpoint(
             revision=revision,
         )
 
-    checkpoint = load_state_dict(pretrained_model_link_or_path, no_mmap=no_mmap)
+    checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
 
     # some checkpoints contain the model state dict under a "state_dict" key
     while "state_dict" in checkpoint:

src/diffusers/models/model_loading_utils.py

Lines changed: 4 additions & 2 deletions
@@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class
 
 
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, no_mmap: bool = False):
+def load_state_dict(
+    checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
+):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
@@ -142,7 +144,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
-            if no_mmap:
+            if disable_mmap:
                 return safetensors.torch.load(open(checkpoint_file, "rb").read())
             else:
                 return safetensors.torch.load_file(checkpoint_file, device="cpu")
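This is the heart of the change: with `disable_mmap=True`, `safetensors.torch.load` deserializes the file's raw bytes from memory, while the default path uses `safetensors.torch.load_file`, which memory-maps the file and pages tensor data in lazily. A standalone sketch of the two paths; the function name is hypothetical, and unlike the diff above it closes the file handle explicitly:

    import safetensors.torch

    def load_safetensors(path: str, disable_mmap: bool = False):
        """Load a .safetensors checkpoint, optionally bypassing mmap."""
        if disable_mmap:
            # One sequential read of the whole file, then an in-memory
            # deserialize: friendlier to network mounts and spinning disks
            # than mmap's random page-ins.
            with open(path, "rb") as f:
                return safetensors.torch.load(f.read())
        # Default: memory-mapped load; tensor data is paged in on demand.
        return safetensors.torch.load_file(path, device="cpu")

The trade-off is memory for I/O pattern: the non-mmap path holds the full checkpoint in RAM at once, but replaces mmap's scattered page faults with a single sequential read.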

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 2 deletions
@@ -541,6 +541,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
 
         <Tip>
 
@@ -586,6 +589,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -865,7 +869,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                 else:
                     param_device = torch.device(torch.cuda.current_device())
-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
                 model._convert_deprecated_attention_blocks(state_dict)
 
                 # move the params from meta device to cpu
@@ -965,7 +969,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         else:
             model = cls.from_config(config, **unused_kwargs)
 
-            state_dict = load_state_dict(model_file, variant=variant)
+            state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
             model._convert_deprecated_attention_blocks(state_dict)
 
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
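Because `from_pretrained` now pops `disable_mmap` from `kwargs` and forwards it to `load_state_dict`, the flag also applies to multi-folder Diffusers checkpoints, at least on the two loading paths this commit touches. A hedged sketch; the repo id is only an example:

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # example Hub repo
        subfolder="unet",
        disable_mmap=True,  # forwarded to load_state_dict for safetensors weights
    )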
