Commit 343c5cb

Merge branch 'main' into main
2 parents 8fe6408 + 9f06a0d commit 343c5cb

11 files changed: +68 additions, −36 deletions

src/diffusers/loaders/single_file.py

Lines changed: 8 additions & 0 deletions

@@ -60,6 +60,7 @@ def load_single_file_sub_model(
     local_files_only=False,
     torch_dtype=None,
     is_legacy_loading=False,
+    disable_mmap=False,
     **kwargs,
 ):
     if is_pipeline_module:

@@ -106,6 +107,7 @@ def load_single_file_sub_model(
         subfolder=name,
         torch_dtype=torch_dtype,
         local_files_only=local_files_only,
+        disable_mmap=disable_mmap,
         **kwargs,
     )

@@ -308,6 +310,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
                     hosted on the Hub.
                 - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
                     component configs in Diffusers format.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                 class). The overwritten components are passed directly to the pipelines `__init__` method. See example

@@ -355,6 +360,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)

         is_legacy_loading = False

@@ -383,6 +389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             revision=revision,
+            disable_mmap=disable_mmap,
         )

         if config is None:

@@ -504,6 +511,7 @@ def load_module(name, value):
                     original_config=original_config,
                     local_files_only=local_files_only,
                     is_legacy_loading=is_legacy_loading,
+                    disable_mmap=disable_mmap,
                     **kwargs,
                 )
             except SingleFileComponentError as e:
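
Taken together, these changes let callers opt out of memory-mapping at the pipeline level. A minimal usage sketch (the pipeline class and checkpoint path below are illustrative assumptions, not part of this commit):

    from diffusers import StableDiffusionPipeline

    # Hypothetical single-file checkpoint sitting on a network mount.
    pipe = StableDiffusionPipeline.from_single_file(
        "/mnt/nfs/checkpoints/model.safetensors",
        disable_mmap=True,  # read the whole file up front instead of mmap-ing it
    )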

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions

@@ -187,6 +187,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (for example the pipeline components of the
                 specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`

@@ -234,6 +237,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)

         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict

@@ -246,6 +250,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 cache_dir=cache_dir,
                 local_files_only=local_files_only,
                 revision=revision,
+                disable_mmap=disable_mmap,
             )
         if quantization_config is not None:
             hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
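
The same flag is exposed by the model-level single-file loader. A minimal sketch, assuming a Flux transformer checkpoint (the class and file name are illustrative, not taken from the commit):

    from diffusers import FluxTransformer2DModel

    # Hypothetical checkpoint file; any model class supporting `from_single_file` works the same way.
    transformer = FluxTransformer2DModel.from_single_file(
        "./flux1-dev.safetensors",
        disable_mmap=True,
    )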

src/diffusers/loaders/single_file_utils.py

Lines changed: 2 additions & 1 deletion

@@ -387,6 +387,7 @@ def load_single_file_checkpoint(
     cache_dir=None,
     local_files_only=None,
     revision=None,
+    disable_mmap=False,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
         pretrained_model_link_or_path = pretrained_model_link_or_path

@@ -404,7 +405,7 @@ def load_single_file_checkpoint(
             revision=revision,
         )

-    checkpoint = load_state_dict(pretrained_model_link_or_path)
+    checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)

     # some checkpoints contain the model state dict under a "state_dict" key
     while "state_dict" in checkpoint:

src/diffusers/models/model_loading_utils.py

Lines changed: 7 additions & 2 deletions

@@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class


-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def load_state_dict(
+    checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
+):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """

@@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+            if disable_mmap:
+                return safetensors.torch.load(open(checkpoint_file, "rb").read())
+            else:
+                return safetensors.torch.load_file(checkpoint_file, device="cpu")
         elif file_extension == GGUF_FILE_EXTENSION:
             return load_gguf_checkpoint(checkpoint_file)
         else:
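
For context on the new branch: `safetensors.torch.load_file` memory-maps the checkpoint and pages tensors in on access, while `safetensors.torch.load` deserializes from an in-memory bytes object, so reading the file up front avoids mmap seek patterns on network mounts or spinning disks. A standalone sketch of the two paths (the file name is hypothetical):

    import safetensors.torch

    path = "model.safetensors"  # hypothetical checkpoint

    # Default path: mmap-backed loading.
    mmap_state_dict = safetensors.torch.load_file(path, device="cpu")

    # disable_mmap path: read the whole file into memory, then deserialize.
    with open(path, "rb") as f:
        eager_state_dict = safetensors.torch.load(f.read())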

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 2 deletions

@@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

         <Tip>

@@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)

         allow_pickle = False
         if use_safetensors is None:

@@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                 else:
                     param_device = torch.device(torch.cuda.current_device())
-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
                 model._convert_deprecated_attention_blocks(state_dict)

                 # move the params from meta device to cpu

@@ -979,7 +983,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             else:
                 model = cls.from_config(config, **unused_kwargs)

-            state_dict = load_state_dict(model_file, variant=variant)
+            state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
             model._convert_deprecated_attention_blocks(state_dict)

             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
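
Since `from_pretrained` now also pops the flag, it can be passed when loading an individual model component from a multi-folder Diffusers repository. A small sketch (the repo id and subfolder are illustrative assumptions):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "some-org/some-model",  # hypothetical repository
        subfolder="unet",
        disable_mmap=True,
    )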

tests/lora/test_lora_layers_flux.py

Lines changed: 10 additions & 6 deletions

@@ -796,8 +796,8 @@ def test_modify_padding_mode(self):
 @nightly
 @require_torch_gpu
 @require_peft_backend
-@unittest.skip("We cannot run inference on this model with the current CI hardware")
-# TODO (DN6, sayakpaul): move these tests to a beefier GPU
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
 class FluxLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on audace.

@@ -819,14 +819,18 @@ def setUp(self):
     def tearDown(self):
         super().tearDown()

+        del self.pipeline
         gc.collect()
         torch.cuda.empty_cache()

     def test_flux_the_last_ben(self):
         self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
         self.pipeline.fuse_lora()
         self.pipeline.unload_lora_weights()
-        self.pipeline.enable_model_cpu_offload()
+        # Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI
+        # run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
+        # `enable_model_cpu_offload()`. We repeat this for the other tests, too.
+        self.pipeline = self.pipeline.to(torch_device)

         prompt = "jon snow eating pizza with ketchup"

@@ -848,7 +852,7 @@ def test_flux_kohya(self):
         self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
         self.pipeline.fuse_lora()
         self.pipeline.unload_lora_weights()
-        self.pipeline.enable_model_cpu_offload()
+        self.pipeline = self.pipeline.to(torch_device)

         prompt = "The cat with a brain slug earring"
         out = self.pipeline(

@@ -870,7 +874,7 @@ def test_flux_kohya_with_text_encoder(self):
         self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
         self.pipeline.fuse_lora()
         self.pipeline.unload_lora_weights()
-        self.pipeline.enable_model_cpu_offload()
+        self.pipeline = self.pipeline.to(torch_device)

         prompt = "optimus is cleaning the house with broomstick"
         out = self.pipeline(

@@ -892,7 +896,7 @@ def test_flux_xlabs(self):
         self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
         self.pipeline.fuse_lora()
         self.pipeline.unload_lora_weights()
-        self.pipeline.enable_model_cpu_offload()
+        self.pipeline = self.pipeline.to(torch_device)

         prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
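
These test changes swap the blanket `unittest.skip` for the big-GPU gating used elsewhere in the test suite. A hypothetical sketch of the decorator pattern; assuming the `big_gpu_with_torch_cuda` marker is registered in the project's pytest configuration, such tests can be selected with `pytest -m big_gpu_with_torch_cuda`:

    import unittest

    import pytest

    from diffusers.utils.testing_utils import nightly, require_big_gpu_with_torch_cuda


    @nightly
    @require_big_gpu_with_torch_cuda
    @pytest.mark.big_gpu_with_torch_cuda
    class ExampleBigModelIntegrationTests(unittest.TestCase):  # hypothetical class for illustration
        def test_something_heavy(self):
            ...  # runs only on nightly CI with a large-memory CUDA GPU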

tests/lora/test_lora_layers_sd3.py

Lines changed: 11 additions & 6 deletions

@@ -17,6 +17,7 @@
 import unittest

 import numpy as np
+import pytest
 import torch
 from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

@@ -31,9 +32,9 @@
 from diffusers.utils.testing_utils import (
     nightly,
     numpy_cosine_similarity_distance,
+    require_big_gpu_with_torch_cuda,
     require_peft_backend,
     require_torch_gpu,
-    slow,
     torch_device,
 )

@@ -128,11 +129,12 @@ def test_modify_padding_mode(self):
         pass


-@slow
 @nightly
 @require_torch_gpu
 @require_peft_backend
-class LoraSD3IntegrationTests(unittest.TestCase):
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
+class SD3LoraIntegrationTests(unittest.TestCase):
     pipeline_class = StableDiffusion3Img2ImgPipeline
     repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

@@ -166,14 +168,17 @@ def get_inputs(self, device, seed=0):

     def test_sd3_img2img_lora(self):
         pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
-        pipe.load_lora_weights("zwloong/sd3-lora-training-rank16-v2", weight_name="pytorch_lora_weights.safetensors")
-        pipe.enable_sequential_cpu_offload()
+        pipe.load_lora_weights("zwloong/sd3-lora-training-rank16-v2")
+        pipe.fuse_lora()
+        pipe.unload_lora_weights()
+        pipe = pipe.to(torch_device)

         inputs = self.get_inputs(torch_device)

         image = pipe(**inputs).images[0]
         image_slice = image[0, -3:, -3:]
-        expected_slice = np.array([0.5396, 0.5776, 0.7432, 0.5151, 0.5586, 0.7383, 0.5537, 0.5933, 0.7153])
+        expected_slice = np.array([0.5649, 0.5405, 0.5488, 0.5688, 0.5449, 0.5513, 0.5337, 0.5107, 0.5059])
+
         max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

         assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"

tests/pipelines/controlnet_flux/test_controlnet_flux.py

Lines changed: 5 additions & 6 deletions

@@ -32,9 +32,9 @@
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    nightly,
     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
-    slow,
     torch_device,
 )
 from diffusers.utils.torch_utils import randn_tensor

@@ -204,7 +204,7 @@ def test_flux_image_output_shape(self):
         assert (output_height, output_width) == (expected_height, expected_width)


-@slow
+@nightly
 @require_big_gpu_with_torch_cuda
 @pytest.mark.big_gpu_with_torch_cuda
 class FluxControlNetPipelineSlowTests(unittest.TestCase):

@@ -230,8 +230,7 @@ def test_canny(self):
             text_encoder_2=None,
             controlnet=controlnet,
             torch_dtype=torch.bfloat16,
-        )
-        pipe.enable_model_cpu_offload()
+        ).to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)

@@ -241,12 +240,12 @@

         prompt_embeds = torch.load(
             hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
-        )
+        ).to(torch_device)
         pooled_prompt_embeds = torch.load(
             hf_hub_download(
                 repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
             )
-        )
+        ).to(torch_device)

         output = pipe(
             prompt_embeds=prompt_embeds,

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 6 additions & 9 deletions

@@ -9,6 +9,7 @@

 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
 from diffusers.utils.testing_utils import (
+    nightly,
     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
     slow,

@@ -209,7 +210,7 @@ def test_flux_image_output_shape(self):
         assert (output_height, output_width) == (expected_height, expected_width)


-@slow
+@nightly
 @require_big_gpu_with_torch_cuda
 @pytest.mark.big_gpu_with_torch_cuda
 class FluxPipelineSlowTests(unittest.TestCase):

@@ -227,19 +228,16 @@ def tearDown(self):
         torch.cuda.empty_cache()

     def get_inputs(self, device, seed=0):
-        if str(device).startswith("mps"):
-            generator = torch.manual_seed(seed)
-        else:
-            generator = torch.Generator(device="cpu").manual_seed(seed)
+        generator = torch.Generator(device="cpu").manual_seed(seed)

         prompt_embeds = torch.load(
             hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
-        )
+        ).to(torch_device)
         pooled_prompt_embeds = torch.load(
             hf_hub_download(
                 repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
             )
-        )
+        ).to(torch_device)
         return {
             "prompt_embeds": prompt_embeds,
             "pooled_prompt_embeds": pooled_prompt_embeds,

@@ -253,8 +251,7 @@ def get_inputs(self, device, seed=0):
     def test_flux_inference(self):
         pipe = self.pipeline_class.from_pretrained(
             self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
-        )
-        pipe.enable_model_cpu_offload()
+        ).to(torch_device)

         inputs = self.get_inputs(torch_device)

tests/pipelines/mochi/test_mochi.py

Lines changed: 7 additions & 3 deletions

@@ -17,15 +17,17 @@
 import unittest

 import numpy as np
+import pytest
 import torch
 from transformers import AutoTokenizer, T5EncoderModel

 from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    nightly,
     numpy_cosine_similarity_distance,
+    require_big_gpu_with_torch_cuda,
     require_torch_gpu,
-    slow,
     torch_device,
 )

@@ -260,8 +262,10 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
     )


-@slow
+@nightly
 @require_torch_gpu
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
 class MochiPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."

@@ -293,7 +297,7 @@ def test_mochi(self):
         ).frames

         video = videos[0]
-        expected_video = torch.randn(1, 16, 480, 848, 3).numpy()
+        expected_video = torch.randn(1, 19, 480, 848, 3).numpy()

         max_diff = numpy_cosine_similarity_distance(video, expected_video)
         assert max_diff < 1e-3, f"Max diff is too high. got {video}"