Skip to content

Commit b2e838f

Browse files
committed
why still use slightly more memory when less memory do trick
1 parent 8ba2bda commit b2e838f

File tree

1 file changed

+25
-28
lines changed

1 file changed

+25
-28
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -455,41 +455,40 @@ def _apply_group_offloading_leaf_level(
455455
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
456456

457457
# Find closest module parent for each parameter and buffer, and attach group hooks
458-
common_kwargs = {
459-
"modules": [],
460-
"offload_device": offload_device,
461-
"onload_device": onload_device,
462-
"non_blocking": non_blocking,
463-
"stream": stream,
464-
"cpu_param_dict": cpu_param_dict,
465-
"onload_self": True,
466-
}
467-
458+
parent_to_parameters = {}
468459
for name, param in parameters:
469460
parent_name = _find_parent_module_in_module_dict(name, module_dict)
470-
parent_module = module_dict[parent_name]
471-
logger.info(f"TODO: REMOVETHIS Found parameter {name} with parent module {parent_name}")
472-
assert getattr(parent_module, "_diffusers_hook", None) is None
473-
group = ModuleGroup(
474-
offload_leader=parent_module,
475-
onload_leader=parent_module,
476-
parameters=[param],
477-
buffers=None,
478-
**common_kwargs,
479-
)
480-
_apply_group_offloading_hook(parent_module, group, True, None)
461+
if parent_name in parent_to_parameters:
462+
parent_to_parameters[parent_name].append(param)
463+
else:
464+
parent_to_parameters[parent_name] = [param]
481465

466+
parent_to_buffers = {}
482467
for name, buffer in buffers:
483468
parent_name = _find_parent_module_in_module_dict(name, module_dict)
484-
parent_module = module_dict[parent_name]
485-
logger.info(f"TODO: REMOVETHIS Found buffer {name} with parent module {parent_name}")
469+
if parent_name in parent_to_buffers:
470+
parent_to_buffers[parent_name].append(buffer)
471+
else:
472+
parent_to_buffers[parent_name] = [buffer]
473+
474+
parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
475+
for name in parent_names:
476+
parameters = parent_to_parameters.get(name, [])
477+
buffers = parent_to_buffers.get(name, [])
478+
parent_module = module_dict[name]
486479
assert getattr(parent_module, "_diffusers_hook", None) is None
487480
group = ModuleGroup(
481+
modules=[],
482+
offload_device=offload_device,
483+
onload_device=onload_device,
488484
offload_leader=parent_module,
489485
onload_leader=parent_module,
490-
parameters=None,
491-
buffers=[buffer],
492-
**common_kwargs,
486+
parameters=parameters,
487+
buffers=buffers,
488+
non_blocking=non_blocking,
489+
stream=stream,
490+
cpu_param_dict=cpu_param_dict,
491+
onload_self=True,
493492
)
494493
_apply_group_offloading_hook(parent_module, group, True, None)
495494

@@ -557,7 +556,6 @@ def _gather_parameters_with_no_group_offloading_parent(
557556
atoms.pop()
558557

559558
if not has_parent_with_group_offloading:
560-
logger.info(f"TODO: REMOVETHIS Found parameter {name} with no parent module with group offloading")
561559
parameters.append((name, parameter))
562560
return parameters
563561

@@ -578,7 +576,6 @@ def _gather_buffers_with_no_group_offloading_parent(
578576
atoms.pop()
579577

580578
if not has_parent_with_group_offloading:
581-
logger.info(f"TODO: REMOVETHIS Found buffer {name} with no parent module with group offloading")
582579
buffers.append((name, buffer))
583580
return buffers
584581

0 commit comments

Comments
 (0)