
Commit d2a2981

update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite
1 parent 80ac5a7 commit d2a2981

File tree: 2 files changed (+89, −31 lines)


src/diffusers/hooks/group_offloading.py

Lines changed: 84 additions & 31 deletions
@@ -21,7 +21,7 @@
 from .hooks import HookRegistry, ModelHook
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
+logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
 class ModuleGroup:
@@ -32,12 +32,16 @@ def __init__(
         onload_device: torch.device,
         offload_leader: torch.nn.Module,
         onload_leader: Optional[torch.nn.Module] = None,
+        parameters: Optional[List[torch.nn.Parameter]] = None,
+        buffers: Optional[List[torch.Tensor]] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
+        self.parameters = parameters
+        self.buffers = buffers
 
 
 class GroupOffloadingHook(ModelHook):
@@ -64,13 +68,15 @@ def __init__(
         stream: Optional[torch.cuda.Stream] = None,
         next_group: Optional[ModuleGroup] = None,
         cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        onload_self: bool = False,
     ) -> None:
         self.group = group
         self.offload_on_init = offload_on_init
         self.non_blocking = non_blocking
         self.stream = stream
         self.next_group = next_group
         self.cpu_param_dict = cpu_param_dict
+        self.onload_self = onload_self
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.offload_on_init:
@@ -100,9 +106,16 @@ def onload_(self, module: torch.nn.Module) -> None:
             with torch.cuda.stream(self.stream):
                 for group_module in self.next_group.modules:
                     group_module.to(self.next_group.onload_device, non_blocking=True)
-        else:
+
+        if self.stream is None or self.onload_self:
             for group_module in self.group.modules:
                 group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
+            if self.group.parameters is not None:
+                for param in self.group.parameters:
+                    param.data = param.data.to(self.group.onload_device, non_blocking=self.non_blocking)
+            if self.group.buffers is not None:
+                for buffer in self.group.buffers:
+                    buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.non_blocking)
 
     def offload_(self, module: torch.nn.Module) -> None:
         if self.group.offload_leader == module:
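The two new branches in onload_ handle tensors that do not live inside any submodule of the group: a bare nn.Parameter or buffer registered directly on the root module has no module wrapper to call .to() on, so its storage is swapped in place through .data. A minimal sketch of that pattern, with illustrative names that are not part of the diffusers API:

import torch

root = torch.nn.Module()
root.scale = torch.nn.Parameter(torch.ones(16))   # a "loose" parameter on the root module
root.register_buffer("offset", torch.zeros(16))   # a "loose" buffer on the root module

def move_loose_tensors(module: torch.nn.Module, device: str, non_blocking: bool = False) -> None:
    # recurse=False yields only tensors registered directly on `module`,
    # i.e. the ones that no submodule group will ever move.
    for param in module.parameters(recurse=False):
        param.data = param.data.to(device, non_blocking=non_blocking)
    for buffer in module.buffers(recurse=False):
        buffer.data = buffer.data.to(device, non_blocking=non_blocking)

move_loose_tensors(root, "cpu")   # offload; pass "cuda" to onload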
@@ -113,6 +126,13 @@ def offload_(self, module: torch.nn.Module) -> None:
             else:
                 for group_module in self.group.modules:
                     group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
+                if self.group.parameters is not None:
+                    for param in self.group.parameters:
+                        param.data = param.data.to(self.group.offload_device, non_blocking=self.non_blocking)
+                if self.group.buffers is not None:
+                    for buffer in self.group.buffers:
+                        buffer.data = buffer.data.to(self.group.offload_device, non_blocking=self.non_blocking)
+
             # TODO: do we need to sync here because of GPU->CPU transfer?
             if self.non_blocking and self.group.offload_device.type == "cpu":
                 torch.cpu.synchronize()
@@ -128,9 +148,9 @@ def apply_group_offloading(
     non_blocking: bool = False,
     cuda_stream: bool = False,
 ) -> None:
-    # stream = None
-    # if cuda_stream:
-    #     stream = torch.cuda.Stream()
+    stream = None
+    if cuda_stream:
+        stream = torch.cuda.Stream()
     if offload_group_patterns == "modulelist_or_sequential":
         if num_blocks_per_group is None:
             raise ValueError(
@@ -148,7 +168,7 @@ def apply_group_offloading(
         offload_group_patterns = _get_modulelist_or_sequential_group_patterns(module, num_blocks_per_group)
 
     _apply_group_offloading_group_patterns(
-        module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking
+        module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking, stream=stream
    )
 
 
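With the stream creation uncommented, cuda_stream=True now produces a torch.cuda.Stream that is threaded through to every per-group hook. A hedged usage sketch on a toy model follows; the keyword names are the ones visible in this diff, but the full public signature is not shown here, so treat the call as an assumption rather than the exact API:

import torch
from diffusers.hooks.group_offloading import apply_group_offloading  # module path from this PR

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(8)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyModel()
apply_group_offloading(
    model,
    offload_group_patterns="modulelist_or_sequential",  # keyword names assumed from this diff
    num_blocks_per_group=2,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    force_offload=True,
    non_blocking=True,
    cuda_stream=True,
)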
@@ -231,6 +251,7 @@ def _apply_group_offloading_group_patterns(
     onload_device: torch.device,
     force_offload: bool,
     non_blocking: bool,
+    stream: Optional[torch.cuda.Stream] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of modules based on the provided regex patterns. Each group of modules
@@ -269,8 +290,17 @@ def _apply_group_offloading_group_patterns(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
+        stream (`torch.cuda.Stream`, *optional*):
+            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
+            for overlapping computation and data transfer.
     """
 
+    cpu_param_dict = None
+    if stream is not None:
+        for param in module.parameters():
+            param.data = param.data.cpu().pin_memory()
+        cpu_param_dict = {param: param.data for param in module.parameters()}
+
     per_group_modules = [[] for _ in range(len(offload_group_patterns))]
     per_group_offload_leaders = [None] * len(offload_group_patterns)
     per_group_onload_leaders = [None] * len(offload_group_patterns)
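The new cpu_param_dict block pins every parameter's CPU copy up front: non_blocking=True host-to-device copies only overlap with compute when the source is page-locked (pinned) memory, and the param-to-pinned-tensor map presumably lets later offloads point parameters back at their pinned CPU copies instead of re-allocating (valid for inference, where weights are not mutated on the GPU). A standalone sketch of the idea, not the diffusers code path:

import torch

def pin_cpu_parameters(module: torch.nn.Module) -> dict:
    # Replace each parameter's CPU storage with a pinned (page-locked) copy and
    # remember it, so later onloads can use asynchronous copies and offloads can
    # reuse the same pinned tensors.
    cpu_param_dict = {}
    for param in module.parameters():
        param.data = param.data.cpu().pin_memory()
        cpu_param_dict[param] = param.data
    return cpu_param_dict

layer = torch.nn.Linear(1024, 1024)
pinned = pin_cpu_parameters(layer)
layer.to("cuda", non_blocking=True)        # asynchronous H2D copy from pinned memory
layer.weight.data = pinned[layer.weight]   # "offload" by pointing back at the pinned CPU copy
layer.bias.data = pinned[layer.bias]       # (assumes the GPU copy was never modified)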
@@ -280,20 +310,20 @@ def _apply_group_offloading_group_patterns(
     offload_leader_patterns = [pattern[1] for pattern in offload_group_patterns]
     onload_leader_patterns = [pattern[2] for pattern in offload_group_patterns]
 
-    for name, module in module.named_modules():
-        if name.count(".") > 1:
+    for name, submodule in module.named_modules():
+        if name == "" or name.count(".") > 1:
             # We only want the layers that are top-level in the module (encompass all the other submodules)
             # for enabling offloading. This method is specifically targeted for diffusers format models,
             # so we can ignore submodules.
             # TODO(aryan): This is not the case and is just a workaround to make the benchmark code work
             # for now. We need to support the arbitrary nesting of modules here.
             continue
-        num_matches = 0
 
         # Check if the module matches any of the offload group patterns
+        num_matches = 0
         for i, pattern in enumerate(group_patterns):
             if re.search(pattern, name) is not None:
-                per_group_modules[i].append(module)
+                per_group_modules[i].append(submodule)
                 num_matches += 1
 
         # Check if the module matches any of the offload leader patterns
@@ -303,7 +333,7 @@
                 raise ValueError(
                     f"Module {name} matches multiple offload leader patterns. Please ensure that offload leader patterns are mutually exclusive."
                 )
-            per_group_offload_leaders[i] = module
+            per_group_offload_leaders[i] = submodule
 
         # Check if the module matches any of the onload leader patterns
         for i, pattern in enumerate(onload_leader_patterns):
@@ -314,16 +344,17 @@
                 raise ValueError(
                     f"Module {name} matches multiple onload leader patterns. Please ensure that onload leader patterns are mutually exclusive."
                 )
-            per_group_onload_leaders[i] = module
+            per_group_onload_leaders[i] = submodule
 
         if num_matches == 0:
-            unmatched_group_modules.append(module)
+            unmatched_group_modules.append((name, submodule))
         elif num_matches > 1:
             raise ValueError(
                 f"Module {name} matches multiple offload group patterns. Please ensure that offloading group patterns are mutually exclusive."
             )
 
     # Handle modules that matched patterns
+    groups = []
     for i in range(len(per_group_modules)):
         if per_group_offload_leaders[i] is None:
             raise ValueError(
@@ -336,21 +367,40 @@
             offload_leader=per_group_offload_leaders[i],
             onload_leader=per_group_onload_leaders[i],
         )
-        _apply_group_offloading(group, force_offload, non_blocking)
-
-    # Handle modules that did not match patterns
-    for module in unmatched_group_modules:
-        group = ModuleGroup([module], offload_device, onload_device, offload_leader=module, onload_leader=module)
-        _apply_group_offloading(group, force_offload, non_blocking)
-
-    # TODO(aryan): When you add stream support, this may need to be put in an if-branch
-    # Always keep parameters and buffers on onload_device
-    for name, param in module.named_parameters(recurse=False):
-        if torch.is_tensor(param.data):
-            param.data = param.data.to(onload_device)
+        groups.append(group)
+
+    for i in range(len(groups)):
+        next_group = groups[i + 1] if i + 1 < len(groups) and stream is not None else None
+        should_offload = force_offload or i > 0
+        _apply_group_offloading(
+            groups[i], should_offload, non_blocking, stream, next_group, cpu_param_dict, onload_self=False
+        )
+
+    # Ignore parameters/buffers if they're already accounted for in unmatched_group_modules (for example, a nn.Linear
+    # in the top-level module will also be present in the named_parameters iterator)
+    parameters = []
+    for name, parameter in module.named_parameters(recurse=False):
+        if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules):
+            parameters.append(parameter)
+
+    buffers = []
     for name, buffer in module.named_buffers(recurse=False):
-        if torch.is_tensor(buffer.data):
-            buffer.data = buffer.data.to(onload_device)
+        if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules):
+            buffers.append(buffer)
+
+    unmatched_modules = [module for _, module in unmatched_group_modules]
+    unmatched_group = ModuleGroup(
+        unmatched_modules,
+        offload_device,
+        onload_device,
+        offload_leader=module,
+        onload_leader=None,
+        parameters=parameters,
+        buffers=buffers,
+    )
+    _apply_group_offloading(
+        unmatched_group, force_offload, non_blocking, stream, groups[0], cpu_param_dict, onload_self=True
+    )
 
 
 def _apply_group_offloading(
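The loop above chains the groups: when a stream is used, group i's hook is handed groups[i + 1] as next_group so that running group i prefetches group i + 1, every group except the first starts offloaded, and the root-level unmatched group (onload_self=True) onloads itself while prefetching groups[0], restarting the cycle on the next forward pass. A self-contained sketch of the compute/transfer overlap this wiring is meant to produce, not the hook implementation itself (truly asynchronous copies additionally require pinned CPU weights, as set up earlier):

import torch

def run_with_prefetch(groups, x, device="cuda"):
    # groups: a list of nn.Module blocks executed in order.
    side_stream = torch.cuda.Stream()
    groups[0].to(device)                                       # the first group onloads itself
    for i, group in enumerate(groups):
        if i + 1 < len(groups):
            with torch.cuda.stream(side_stream):               # prefetch the next group
                groups[i + 1].to(device, non_blocking=True)
        x = group(x)                                           # compute on the default stream
        torch.cuda.current_stream().wait_stream(side_stream)   # copy must finish before it is used
        group.to("cpu")                                        # offload the finished group
    return x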
@@ -360,9 +410,12 @@ def _apply_group_offloading(
     stream: Optional[torch.cuda.Stream] = None,
     next_group: Optional[ModuleGroup] = None,
     cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+    onload_self: bool = False,
 ) -> None:
     for module in group.modules:
-        hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict)
+        hook = GroupOffloadingHook(
+            group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict, onload_self
+        )
         registry = HookRegistry.check_if_exists_or_initialize(module)
         registry.register_hook(hook, "group_offloading")
 
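For readers less familiar with the hook machinery being extended here: every module in a group gets a GroupOffloadingHook registered through HookRegistry, and the hook's pre/post-forward callbacks perform the device moves. A rough stand-in using torch's built-in forward hooks, explicitly not the diffusers HookRegistry, conveys the shape of that behaviour:

import torch

def attach_naive_group_offloading(group_modules, onload_device="cuda", offload_device="cpu"):
    # Onload the whole group before its first module runs and offload it after its
    # last module runs; the leaders are chosen naively as first/last here.
    onload_leader, offload_leader = group_modules[0], group_modules[-1]

    def pre_forward(module, args):
        for m in group_modules:
            m.to(onload_device)
        return args

    def post_forward(module, args, output):
        for m in group_modules:
            m.to(offload_device)
        return output

    onload_leader.register_forward_pre_hook(pre_forward)
    offload_leader.register_forward_hook(post_forward)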
@@ -375,11 +428,11 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl
     blocks. The generated patterns can be used to create ModuleGroup objects which are offloaded and onloaded together.
     """
     group_patterns = []
-
+
     # We only want the layers that are top-level in the module (encompass all the other submodules)
     # for enabling offloading. This method is specifically targeted for diffusers format models,
     # so we can ignore everything but the children of this module.
-    for name, submodule in module.children():
+    for name, submodule in module.named_children():
         if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
             continue
         for i in range(0, len(submodule), num_blocks_per_group):
@@ -389,6 +442,6 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl
             onload_leader_pattern = rf"{name}\.{i}\b"
             offload_leader_pattern = rf"{name}\.{i + num_modules - 1}\b"
             group_patterns.append((pattern, offload_leader_pattern, onload_leader_pattern))
-
+
     logger.debug(f"Generated group patterns for apply_groupwise_offloading: {group_patterns}")
     return group_patterns
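To make the generated patterns concrete: for a ModuleList named "blocks" with four entries and num_blocks_per_group=2, the two leader patterns above match the first and last block index of each chunk. The group pattern itself is built outside the lines shown in this hunk, so the shape used below is only an assumption for illustration:

import re

name, num_blocks, num_blocks_per_group = "blocks", 4, 2
for i in range(0, num_blocks, num_blocks_per_group):
    num_modules = min(num_blocks_per_group, num_blocks - i)
    # Hypothetical group pattern (not shown in this hunk): match every index in the chunk.
    group_pattern = rf"{name}\.({'|'.join(str(i + j) for j in range(num_modules))})\b"
    onload_leader_pattern = rf"{name}\.{i}\b"
    offload_leader_pattern = rf"{name}\.{i + num_modules - 1}\b"
    print(group_pattern, offload_leader_pattern, onload_leader_pattern)

assert re.search(rf"{name}\.1\b", "blocks.1.attn.weight")        # matches block 1
assert not re.search(rf"{name}\.1\b", "blocks.10.attn.weight")   # \b keeps it from matching block 10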

src/diffusers/hooks/hooks.py

Lines changed: 5 additions & 0 deletions
@@ -33,6 +33,7 @@ class ModelHook:
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
+
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -42,6 +43,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is deinitalized.
+
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -53,6 +55,7 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
         r"""
         Hook that is executed just before the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass will be executed just after this event.
@@ -69,6 +72,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A
     def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
         r"""
         Hook that is executed just after the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass been executed just before this event.
@@ -82,6 +86,7 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
     def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when the hook is detached from a module.
+
         Args:
             module (`torch.nn.Module`):
                 The module detached from this hook.
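The docstring-only additions above touch the ModelHook lifecycle methods. A minimal subclass sketch built only from the signatures visible in this diff, with the registration calls mirroring the ones used in group_offloading.py; how and when the registry invokes the callbacks is not shown here, so treat this as an assumption-level example:

import torch
from diffusers.hooks.hooks import HookRegistry, ModelHook  # module path as in this PR

class ShapeLoggingHook(ModelHook):
    # pre_forward/post_forward run around the wrapped module's forward pass.
    def pre_forward(self, module, *args, **kwargs):
        shapes = [tuple(a.shape) for a in args if torch.is_tensor(a)]
        print(f"{module.__class__.__name__} inputs: {shapes}")
        return args, kwargs

    def post_forward(self, module, output):
        if torch.is_tensor(output):
            print(f"{module.__class__.__name__} output: {tuple(output.shape)}")
        return output

layer = torch.nn.Linear(8, 4)
registry = HookRegistry.check_if_exists_or_initialize(layer)   # same call as in group_offloading.py
registry.register_hook(ShapeLoggingHook(), "shape_logging")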
