Commit 30f8f74

Merge branch 'main' into support-comyui-flux-loras
2 parents c30a1e4 + 2c1ed50

26 files changed: +208 / -113 lines

src/diffusers/hooks/group_offloading.py

Lines changed: 84 additions & 54 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple

 import torch
@@ -56,23 +56,58 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage

-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = self._init_cpu_param_dict()
+
+    def _init_cpu_param_dict(self):
+        cpu_param_dict = {}
+        if self.stream is None:
+            return cpu_param_dict
+
+        for module in self.modules:
+            for param in module.parameters():
+                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            for buffer in module.buffers():
+                cpu_param_dict[buffer] = (
+                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+        return cpu_param_dict
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None

     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -82,15 +117,30 @@ def onload_(self):
             self.stream.synchronize()

         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        for buffer in group_module.buffers():
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for param in self.parameters:
+                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for buffer in self.buffers:
+                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)

@@ -101,21 +151,18 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = self.cpu_param_dict[param]
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = self.cpu_param_dict[buffer]
+            for param in self.parameters:
+                param.data = self.cpu_param_dict[param]
+            for buffer in self.buffers:
+                buffer.data = self.cpu_param_dict[buffer]
+
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)


 class GroupOffloadingHook(ModelHook):
@@ -284,6 +331,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage=False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +413,12 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")

@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +451,6 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -425,7 +471,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
@@ -462,7 +508,6 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +543,6 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -515,7 +556,7 @@ def _apply_group_offloading_leaf_level(
                 onload_leader=submodule,
                 non_blocking=non_blocking,
                 stream=stream,
-                cpu_param_dict=cpu_param_dict,
+                low_cpu_mem_usage=low_cpu_mem_usage,
                 onload_self=True,
             )
             _apply_group_offloading_hook(submodule, group, None)
@@ -560,7 +601,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +620,7 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:
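
Note: the low_cpu_mem_usage path above keeps the CPU copies unpinned in _init_cpu_param_dict and pins them on demand inside _pinned_memory_tensors, so page-locked memory is only held while a group is being onloaded. A minimal standalone sketch of that trade-off (not taken from the commit; the tensor size and names are illustrative and a CUDA device is assumed):

import torch

# Stand-in for one offloaded parameter kept as a plain (unpinned) CPU copy,
# as _init_cpu_param_dict does when low_cpu_mem_usage=True.
cpu_copy = torch.randn(1024, 1024)

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    # Pin on demand, mirroring _pinned_memory_tensors: the page-locked buffer
    # exists only for the duration of the transfer, saving CPU RAM at the cost
    # of an extra pin_memory() call per onload.
    pinned = cpu_copy if cpu_copy.is_pinned() else cpu_copy.pin_memory()
    gpu_tensor = pinned.to("cuda", non_blocking=True)
stream.synchronize()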

src/diffusers/loaders/textual_inversion.py

Lines changed: 2 additions & 2 deletions

@@ -449,9 +449,9 @@ def load_textual_inversion(

         # 7.5 Offload the model again
         if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
+            self.enable_model_cpu_offload(device=device)
         elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
+            self.enable_sequential_cpu_offload(device=device)

         # / Unsafe Code >
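
Note: a hedged usage sketch of what passing device preserves here: a pipeline offloaded to a specific accelerator before load_textual_inversion is re-offloaded to that same device afterwards, rather than to the default. Checkpoint and embedding names are illustrative.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload(device="cuda:1")

# Loading embeddings temporarily removes the offload hooks; with this change the
# pipeline is re-offloaded onto "cuda:1" instead of the default "cuda" device.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")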

src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 1 deletion

@@ -546,6 +546,7 @@ def enable_group_offload(
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
+        low_cpu_mem_usage=False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -584,7 +585,14 @@ def enable_group_offload(
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )

     def save_pretrained(
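
Note: a possible call pattern for the extended enable_group_offload signature; the parameter names come from the diff above, while the model class and checkpoint are illustrative.

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# use_stream=True enables overlapped transfers; low_cpu_mem_usage=True keeps the CPU
# copies unpinned until needed (see _init_cpu_param_dict / _pinned_memory_tensors above).
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,
)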

src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py

Lines changed: 0 additions & 1 deletion

@@ -533,7 +533,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):

         return latents

-    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
         image,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 0 additions & 1 deletion

@@ -533,7 +533,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):

         return latents

-    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
         image,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 0 additions & 2 deletions

@@ -561,7 +561,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):

         return latents

-    # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
     def prepare_latents(
         self,
         image,
@@ -614,7 +613,6 @@ def prepare_latents(
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, noise, image_latents, latent_image_ids

-    # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
     def prepare_mask_latents(
         self,
         mask,

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 8 additions & 2 deletions

@@ -225,7 +225,10 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
@@ -634,7 +637,10 @@ def prepare_latents(
             return latents.to(device=device, dtype=dtype), latent_image_ids

         image = image.to(device=device, dtype=dtype)
-        image_latents = self._encode_vae_image(image=image, generator=generator)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
             additional_image_per_prompt = batch_size // image_latents.shape[0]
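
Note: a hedged sketch of what the prepare_latents branch enables: an image input whose channel count equals self.latent_channels (16 for Flux) is treated as VAE latents and is not re-encoded. The checkpoint is illustrative, and the random tensor stands in for latents produced (and already scaled) by an earlier encode step.

import torch
from diffusers import FluxImg2ImgPipeline

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# A 16-channel tensor is passed through VaeImageProcessor (vae_latent_channels above)
# and through prepare_latents without being re-encoded by the VAE.
latents = torch.randn(1, 16, 128, 128, dtype=torch.bfloat16, device="cuda")
result = pipe(prompt="a photo of a cat", image=latents, strength=0.6).images[0]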

src/diffusers/pipelines/flux/pipeline_flux_inpaint.py

Lines changed: 12 additions & 5 deletions

@@ -222,11 +222,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
@@ -653,7 +655,10 @@ def prepare_latents(
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         image = image.to(device=device, dtype=dtype)
-        image_latents = self._encode_vae_image(image=image, generator=generator)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image

         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -710,7 +715,9 @@ def prepare_mask_latents(
         else:
             masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

-        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+        masked_image_latents = (
+            masked_image_latents - self.vae.config.shift_factor
+        ) * self.vae.config.scaling_factor

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
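
Note: the same latents pass-through applies to FluxInpaintPipeline via the shared VaeImageProcessor change. A small sketch of the assumed processor behavior, using illustrative values:

import torch
from diffusers.image_processor import VaeImageProcessor

# Mirrors the pipelines' __init__ above: vae_scale_factor * 2 accounts for the 2x2
# patching, vae_latent_channels marks 16-channel tensors as already-encoded latents.
processor = VaeImageProcessor(vae_scale_factor=16, vae_latent_channels=16)

latents = torch.randn(1, 16, 128, 128)
# Assumption: with vae_latent_channels set, a 16-channel tensor should be returned
# by preprocess() without resizing or normalization.
out = processor.preprocess(latents)
print(out.shape)  # expected: torch.Size([1, 16, 128, 128])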
