Skip to content

Commit 8ba2bda

Browse files
committed
why use more memory when less memory do trick
1 parent 073d4bc commit 8ba2bda

File tree

1 file changed

+111
-29
lines changed

1 file changed

+111
-29
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 111 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from contextlib import nullcontext
16 - from typing import Dict, List, Optional, Tuple
16 + from typing import Dict, List, Optional, Set, Tuple
1717

1818
import torch
1919
from accelerate.utils import send_to_device
@@ -284,6 +284,8 @@ def apply_group_offloading(
284284
_apply_group_offloading_leaf_level(
285285
module, offload_device, onload_device, force_offload, non_blocking, stream=stream
286286
)
287+
else:
288+
raise ValueError(f"Unsupported offload_type: {offload_type}")
287289

288290

289291
def _apply_group_offloading_block_level(
@@ -325,12 +327,15 @@ def _apply_group_offloading_block_level(
325327
cpu_param_dict = {param: param.data for param in module.parameters()}
326328

327329
# Create module groups for ModuleList and Sequential blocks
330+
modules_with_group_offloading = set()
328331
unmatched_modules = []
329332
matched_module_groups = []
330333
for name, submodule in module.named_children():
331334
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
332335
unmatched_modules.append((name, submodule))
336+
modules_with_group_offloading.add(name)
333337
continue
338+
334339
for i in range(0, len(submodule), num_blocks_per_group):
335340
current_modules = submodule[i : i + num_blocks_per_group]
336341
group = ModuleGroup(
@@ -345,6 +350,8 @@ def _apply_group_offloading_block_level(
345350
onload_self=stream is None,
346351
)
347352
matched_module_groups.append(group)
353+
for j in range(i, i + len(current_modules)):
354+
modules_with_group_offloading.add(f"{name}.{j}")
348355

349356
# Apply group offloading hooks to the module groups
350357
for i, group in enumerate(matched_module_groups):
@@ -359,15 +366,10 @@ def _apply_group_offloading_block_level(
359366
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
360367
# when the forward pass of this module is called. This is because the top-level module is not
361368
# part of any group (as doing so would lead to no VRAM savings).
362-
parameters = []
363-
for name, parameter in module.named_parameters(recurse=False):
364-
if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules):
365-
parameters.append(parameter)
366-
367-
buffers = []
368-
for name, buffer in module.named_buffers(recurse=False):
369-
if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules):
370-
buffers.append(buffer)
369+
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
370+
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
371+
parameters = [param for _, param in parameters]
372+
buffers = [buffer for _, buffer in buffers]
371373

372374
# Create a group for the unmatched submodules of the top-level module so that they are on the correct
373375
# device when the forward pass is called.
@@ -428,7 +430,8 @@ def _apply_group_offloading_leaf_level(
428430
cpu_param_dict = {param: param.data for param in module.parameters()}
429431

430432
# Create module groups for leaf modules and apply group offloading hooks
431-
for submodule in module.modules():
433+
modules_with_group_offloading = set()
434+
for name, submodule in module.named_modules():
432435
if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
433436
continue
434437
group = ModuleGroup(
@@ -443,38 +446,65 @@ def _apply_group_offloading_leaf_level(
443446
onload_self=True,
444447
)
445448
_apply_group_offloading_hook(submodule, group, True, None)
449+
modules_with_group_offloading.add(name)
446450

447451
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
448452
# of the module is called
449-
parameters = []
450-
buffers = []
451453
module_dict = dict(module.named_modules())
454+
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
455+
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
456+
457+
# Find closest module parent for each parameter and buffer, and attach group hooks
458+
common_kwargs = {
459+
"modules": [],
460+
"offload_device": offload_device,
461+
"onload_device": onload_device,
462+
"non_blocking": non_blocking,
463+
"stream": stream,
464+
"cpu_param_dict": cpu_param_dict,
465+
"onload_self": True,
466+
}
467+
468+
for name, param in parameters:
469+
parent_name = _find_parent_module_in_module_dict(name, module_dict)
470+
parent_module = module_dict[parent_name]
471+
logger.info(f"TODO: REMOVETHIS Found parameter {name} with parent module {parent_name}")
472+
assert getattr(parent_module, "_diffusers_hook", None) is None
473+
group = ModuleGroup(
474+
offload_leader=parent_module,
475+
onload_leader=parent_module,
476+
parameters=[param],
477+
buffers=None,
478+
**common_kwargs,
479+
)
480+
_apply_group_offloading_hook(parent_module, group, True, None)
452481

453-
for name, parameter in module.named_parameters():
454-
atoms = name.split(".")
455-
parent_name = ".".join(atoms[:-1])
456-
if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS):
457-
continue
458-
parameters.append(parameter)
459-
460-
for name, buffer in module.named_buffers():
461-
atoms = name.split(".")
462-
parent_name = ".".join(atoms[:-1])
463-
if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS):
464-
continue
465-
buffers.append(buffer)
482+
for name, buffer in buffers:
483+
parent_name = _find_parent_module_in_module_dict(name, module_dict)
484+
parent_module = module_dict[parent_name]
485+
logger.info(f"TODO: REMOVETHIS Found buffer {name} with parent module {parent_name}")
486+
assert getattr(parent_module, "_diffusers_hook", None) is None
487+
group = ModuleGroup(
488+
offload_leader=parent_module,
489+
onload_leader=parent_module,
490+
parameters=None,
491+
buffers=[buffer],
492+
**common_kwargs,
493+
)
494+
_apply_group_offloading_hook(parent_module, group, True, None)
466495

496+
# This is a dummy group that will handle lazy prefetching from the top-level module to the first leaf module
467497
unmatched_group = ModuleGroup(
468498
modules=[],
469499
offload_device=offload_device,
470500
onload_device=onload_device,
471501
offload_leader=module,
472502
onload_leader=module,
473-
parameters=parameters,
474-
buffers=buffers,
503+
parameters=None,
504+
buffers=None,
475505
non_blocking=False,
476506
stream=None,
477-
cpu_param_dict=cpu_param_dict,
507+
cpu_param_dict=None,
478508
onload_self=True,
479509
)
480510

@@ -509,3 +539,55 @@ def _apply_lazy_group_offloading_hook(
509539
registry = HookRegistry.check_if_exists_or_initialize(module)
510540
registry.register_hook(hook, _GROUP_OFFLOADING)
511541
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
542+
543+
544+
def _gather_parameters_with_no_group_offloading_parent(
545+
module: torch.nn.Module, modules_with_group_offloading: Set[str]
546+
) -> List[torch.nn.Parameter]:
547+
parameters = []
548+
for name, parameter in module.named_parameters():
549+
has_parent_with_group_offloading = False
550+
atoms = name.split(".")
551+
552+
while len(atoms) > 0:
553+
parent_name = ".".join(atoms)
554+
if parent_name in modules_with_group_offloading:
555+
has_parent_with_group_offloading = True
556+
break
557+
atoms.pop()
558+
559+
if not has_parent_with_group_offloading:
560+
logger.info(f"TODO: REMOVETHIS Found parameter {name} with no parent module with group offloading")
561+
parameters.append((name, parameter))
562+
return parameters
563+
564+
565+
def _gather_buffers_with_no_group_offloading_parent(
566+
module: torch.nn.Module, modules_with_group_offloading: Set[str]
567+
) -> List[torch.Tensor]:
568+
buffers = []
569+
for name, buffer in module.named_buffers():
570+
has_parent_with_group_offloading = False
571+
atoms = name.split(".")
572+
573+
while len(atoms) > 0:
574+
parent_name = ".".join(atoms)
575+
if parent_name in modules_with_group_offloading:
576+
has_parent_with_group_offloading = True
577+
break
578+
atoms.pop()
579+
580+
if not has_parent_with_group_offloading:
581+
logger.info(f"TODO: REMOVETHIS Found buffer {name} with no parent module with group offloading")
582+
buffers.append((name, buffer))
583+
return buffers
584+
585+
586+
def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
587+
atoms = name.split(".")
588+
while len(atoms) > 0:
589+
parent_name = ".".join(atoms)
590+
if parent_name in module_dict:
591+
return parent_name
592+
atoms.pop()
593+
return ""

0 commit comments

Comments (0)