
Commit 514f1d7

Merge branch 'main' into flux-quantized-w-lora

2 parents: b504f61 + f424b1b

21 files changed: +231, -129 lines

examples/community/README.md

Lines changed: 45 additions & 20 deletions
Large diffs are not rendered by default.

src/diffusers/hooks/group_offloading.py

Lines changed: 84 additions & 54 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple

 import torch
@@ -56,23 +56,58 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage

-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = self._init_cpu_param_dict()
+
+    def _init_cpu_param_dict(self):
+        cpu_param_dict = {}
+        if self.stream is None:
+            return cpu_param_dict
+
+        for module in self.modules:
+            for param in module.parameters():
+                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            for buffer in module.buffers():
+                cpu_param_dict[buffer] = (
+                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+        return cpu_param_dict
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None

     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -82,15 +117,30 @@ def onload_(self):
             self.stream.synchronize()

         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        for buffer in group_module.buffers():
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for param in self.parameters:
+                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for buffer in self.buffers:
+                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)

@@ -101,21 +151,18 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = self.cpu_param_dict[param]
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = self.cpu_param_dict[buffer]
+            for param in self.parameters:
+                param.data = self.cpu_param_dict[param]
+            for buffer in self.buffers:
+                buffer.data = self.cpu_param_dict[buffer]
+
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)


 class GroupOffloadingHook(ModelHook):
@@ -284,6 +331,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage=False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +413,12 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")

@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +451,6 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -425,7 +471,7 @@
                 onload_leader=current_modules[0],
                 non_blocking=non_blocking,
                 stream=stream,
-                cpu_param_dict=cpu_param_dict,
+                low_cpu_mem_usage=low_cpu_mem_usage,
                 onload_self=stream is None,
             )
             matched_module_groups.append(group)
@@ -462,7 +508,6 @@
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
    onload_device: torch.device,
    non_blocking: bool,
    stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
    r"""
    This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +543,6 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -515,7 +556,7 @@
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
@@ -560,7 +601,7 @@
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +620,7 @@
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:
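
Taken together, these changes remove the externally supplied cpu_param_dict: each ModuleGroup now builds its own CPU copies in _init_cpu_param_dict, and when low_cpu_mem_usage=True those copies stay unpinned and are pinned only temporarily inside _pinned_memory_tensors during onloading, trading some transfer speed for lower host-memory pressure. A minimal usage sketch of the updated entry point, assuming a CUDA machine (the checkpoint and device choices below are illustrative and not part of this commit):

    import torch
    from diffusers import AutoencoderKL
    from diffusers.hooks import apply_group_offloading

    # Illustrative model; any diffusers torch.nn.Module-based model is handled the same way.
    model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,          # async host-to-device transfers on a separate CUDA stream
        low_cpu_mem_usage=True,   # new flag: pin host memory lazily per onload instead of upfront
    )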

src/diffusers/loaders/textual_inversion.py

Lines changed: 2 additions & 2 deletions
@@ -449,9 +449,9 @@ def load_textual_inversion(

         # 7.5 Offload the model again
         if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
+            self.enable_model_cpu_offload(device=device)
         elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
+            self.enable_sequential_cpu_offload(device=device)

         # / Unsafe Code >

src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 1 deletion
@@ -546,6 +546,7 @@ def enable_group_offload(
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
+        low_cpu_mem_usage=False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -584,7 +585,14 @@ def enable_group_offload(
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )

     def save_pretrained(
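
With the new keyword threaded through enable_group_offload, the behaviour can be toggled from a model instance without touching the hooks module directly. A hedged sketch, assuming a Flux-style transformer on a CUDA machine (checkpoint and values are illustrative):

    import torch
    from diffusers import FluxTransformer2DModel

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=2,
        use_stream=True,
        low_cpu_mem_usage=True,  # forwarded to apply_group_offloading by this change
    )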

src/diffusers/models/normalization.py

Lines changed: 0 additions & 10 deletions
@@ -550,16 +550,6 @@ def forward(self, hidden_states):
             hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
-        elif is_torch_version(">=", "2.4"):
-            if self.weight is not None:
-                # convert into half-precision if necessary
-                if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = nn.functional.rms_norm(
-                hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
-            )
-            if self.bias is not None:
-                hidden_states = hidden_states + self.bias
         else:
             input_dtype = hidden_states.dtype
             variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
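
With the torch >= 2.4 fast path removed, every non-NPU input now goes through the manual computation in the else branch. For reference, that computation is roughly the standalone sketch below (bias handling omitted; the function name and signature are illustrative, not the class in this file):

    from typing import Optional

    import torch

    def rms_norm_reference(hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float) -> torch.Tensor:
        # Compute the variance in float32 for numerical stability, then rescale the input.
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        if weight is not None:
            # Convert back to half precision if the learned scale is fp16/bf16, then apply it.
            if weight.dtype in (torch.float16, torch.bfloat16):
                hidden_states = hidden_states.to(weight.dtype)
            hidden_states = hidden_states * weight
        else:
            hidden_states = hidden_states.to(input_dtype)
        return hidden_states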

src/diffusers/utils/testing_utils.py

Lines changed: 40 additions & 0 deletions
@@ -320,6 +320,21 @@ def require_torch_multi_gpu(test_case):
     return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)


+def require_torch_multi_accelerator(test_case):
+    """
+    Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are skipped on a machine
+    without multiple hardware accelerators.
+    """
+    if not is_torch_available():
+        return unittest.skip("test requires PyTorch")(test_case)
+
+    import torch
+
+    return unittest.skipUnless(
+        torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
+    )(test_case)
+
+
 def require_torch_accelerator_with_fp16(test_case):
     """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
     return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
@@ -354,6 +369,31 @@ def require_big_gpu_with_torch_cuda(test_case):
     )(test_case)


+def require_big_accelerator(test_case):
+    """
+    Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
+    Flux, SD3, Cog, etc.
+    """
+    if not is_torch_available():
+        return unittest.skip("test requires PyTorch")(test_case)
+
+    import torch
+
+    if not (torch.cuda.is_available() or torch.xpu.is_available()):
+        return unittest.skip("test requires PyTorch CUDA")(test_case)
+
+    if torch.xpu.is_available():
+        device_properties = torch.xpu.get_device_properties(0)
+    else:
+        device_properties = torch.cuda.get_device_properties(0)
+
+    total_memory = device_properties.total_memory / (1024**3)
+    return unittest.skipUnless(
+        total_memory >= BIG_GPU_MEMORY,
+        f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
+    )(test_case)
+
+
 def require_torch_accelerator_with_training(test_case):
     """Decorator marking a test that requires an accelerator with support for training."""
     return unittest.skipUnless(
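
Both new helpers follow the same pattern as the existing require_* decorators and can be applied to individual test methods or whole test classes. A hedged usage sketch (the test names are illustrative):

    import unittest

    from diffusers.utils.testing_utils import require_big_accelerator, require_torch_multi_accelerator

    class ExampleOffloadTests(unittest.TestCase):
        @require_torch_multi_accelerator
        def test_needs_two_or_more_accelerators(self):
            ...

        @require_big_accelerator
        def test_needs_a_24gb_plus_accelerator(self):
            ...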

tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x
         return model

     def get_generator(self, seed=0):
-        generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
+        generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
         if torch_device != "mps":
             return torch.Generator(device=generator_device).manual_seed(seed)
         return torch.manual_seed(seed)
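
Since a non-empty string always starts with itself, the new condition resolves the generator device to torch_device for every backend, while the existing mps special case still falls back to the global RNG. A standalone equivalent, with an illustrative helper name:

    import torch

    def make_generator(torch_device: str, seed: int = 0) -> torch.Generator:
        if torch_device == "mps":
            # The test suite seeds the global RNG on mps instead of constructing a device generator.
            return torch.manual_seed(seed)
        return torch.Generator(device=torch_device).manual_seed(seed)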
