
Commit 42bc19b

rewrite

1 parent 22aff34 commit 42bc19b

File tree

4 files changed: +269 -68 lines changed

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -3,3 +3,4 @@
 
 if is_torch_available():
     from .group_offloading import apply_group_offloading
+    from .hooks import HookRegistry
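
This one-line change re-exports HookRegistry at the package level. A minimal sketch of what that enables, assuming only the import path shown above:

import torch
from diffusers.hooks import HookRegistry

# Fetch or create the per-module hook registry; this is the same entry point
# the offloading helpers in the next file call internally.
linear = torch.nn.Linear(8, 8)
registry = HookRegistry.check_if_exists_or_initialize(linear)
print(registry)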

src/diffusers/hooks/group_offloading.py

Lines changed: 204 additions & 44 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from contextlib import nullcontext
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from accelerate.utils import send_to_device
@@ -25,6 +25,11 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
+_GROUP_OFFLOADING = "group_offloading"
+_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
+_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
+
+
 class ModuleGroup:
     def __init__(
         self,
@@ -99,6 +104,8 @@ class GroupOffloadingHook(ModelHook):
     group is responsible for onloading the current module group.
     """
 
+    _is_stateful = False
+
     def __init__(
         self,
         group: ModuleGroup,
@@ -132,6 +139,85 @@ def post_forward(self, module: torch.nn.Module, output):
         return output
 
 
+class LazyPrefetchGroupOffloadingHook(ModelHook):
+    _is_stateful = False
+
+    def __init__(self):
+        self.execution_order: List[Tuple[str, torch.nn.Module]] = []
+        self._layer_execution_tracker_module_names = set()
+
+    def initialize_hook(self, module):
+        for name, submodule in module.named_modules():
+            if name == "" or not hasattr(submodule, "_diffusers_hook"):
+                continue
+
+            registry = HookRegistry.check_if_exists_or_initialize(submodule)
+            group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
+
+            if group_offloading_hook is not None:
+
+                def make_execution_order_update_callback(current_name, current_submodule):
+                    def callback():
+                        logger.debug(f"Adding {current_name} to the execution order")
+                        self.execution_order.append((current_name, current_submodule))
+
+                    return callback
+
+                layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
+                registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
+                self._layer_execution_tracker_module_names.add(name)
+
+        return module
+
+    def post_forward(self, module, output):
+        num_executed = len(self.execution_order)
+        execution_order_module_names = {name for name, _ in self.execution_order}
+
+        # Check if the two sets are equal
+        if execution_order_module_names != self._layer_execution_tracker_module_names:
+            unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
+            logger.warning(
+                "It seems like some layers were not executed during the forward pass. This may lead to problems when "
+                "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
+                "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
+                f"{unexecuted_layers=}"
+            )
+
+        base_module_registry = HookRegistry.check_if_exists_or_initialize(module)
+        registries = [HookRegistry.check_if_exists_or_initialize(submodule) for _, submodule in self.execution_order]
+
+        for i in range(num_executed):
+            registries[i].remove_hook(_LAYER_EXECUTION_TRACKER)
+
+        base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING)
+
+        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
+        if num_executed > 0:
+            base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
+            base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
+            base_module_group_offloading_hook.next_group.onload_self = False
+
+        for i in range(num_executed - 1):
+            name1, _ = self.execution_order[i]
+            name2, _ = self.execution_order[i + 1]
+            logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
+            group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
+            group_offloading_hooks[i].next_group.onload_self = False
+
+        return output
+
+
+class LayerExecutionTrackerHook(ModelHook):
+    _is_stateful = False
+
+    def __init__(self, execution_order_update_callback):
+        self.execution_order_update_callback = execution_order_update_callback
+
+    def pre_forward(self, module, *args, **kwargs):
+        self.execution_order_update_callback()
+        return args, kwargs
+
+
 def apply_group_offloading(
     module: torch.nn.Module,
     offload_type: str = "block_level",
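
The lazy prefetch machinery above works in two phases: a LayerExecutionTrackerHook on each offloaded submodule records the execution order during the first forward pass, and post_forward then removes the trackers and chains every group's next_group to the group that actually ran after it. Below is a standalone sketch of the tracing phase using plain PyTorch forward pre-hooks instead of the diffusers HookRegistry; all names here are illustrative:

import torch

execution_order = []

def attach_tracers(model: torch.nn.Module):
    # Record the order in which leaf modules execute during one forward pass.
    handles = []
    for name, submodule in model.named_modules():
        if name == "" or len(list(submodule.children())) != 0:
            continue  # skip the root and non-leaf modules

        def make_callback(current_name):
            # Closure factory, mirroring make_execution_order_update_callback
            # above: it pins the loop variable per submodule, which a bare
            # lambda in the loop body would capture by reference instead.
            def callback(module, args):
                execution_order.append(current_name)
            return callback

        handles.append(submodule.register_forward_pre_hook(make_callback(name)))
    return handles

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
handles = attach_tracers(model)
model(torch.randn(1, 4))
print(execution_order)  # ['0', '1', '2'] for this Sequential
for handle in handles:
    handle.remove()  # tracing is only needed once, as in post_forward above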
@@ -156,10 +242,10 @@ def apply_group_offloading(
         _apply_group_offloading_block_level(
             module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking, stream=stream
         )
-    # elif offload_type == "leaf_level":
-    #     _apply_group_offloading_leaf_level(
-    #         module, offload_device, onload_device, force_offload, non_blocking, stream=stream
-    #     )
+    elif offload_type == "leaf_level":
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, force_offload, non_blocking, stream=stream
+        )
 
 
 def _apply_group_offloading_block_level(
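
With the commented-out branch restored, offload_type now selects between the two strategies. A hedged usage sketch follows; it assumes the remaining parameters of apply_group_offloading keep the defaults visible in the signature above, and the toy model stands in for a real diffusers module:

import torch
from diffusers.hooks import apply_group_offloading

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))

# "leaf_level" offloads every leaf module as its own group. Passing a CUDA
# stream additionally enables lazy prefetching, traced on the first forward.
stream = torch.cuda.Stream() if torch.cuda.is_available() else None
apply_group_offloading(model, offload_type="leaf_level", stream=stream)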
@@ -205,12 +291,13 @@ def _apply_group_offloading_block_level(
             unmatched_modules.append((name, submodule))
             continue
         for i in range(0, len(submodule), num_blocks_per_group):
+            current_modules = submodule[i : i + num_blocks_per_group]
             group = ModuleGroup(
                 modules=submodule[i : i + num_blocks_per_group],
                 offload_device=offload_device,
                 onload_device=onload_device,
-                offload_leader=submodule[i],
-                onload_leader=None,
+                offload_leader=current_modules[-1],
+                onload_leader=current_modules[0],
                 non_blocking=non_blocking,
                 stream=stream,
                 cpu_param_dict=cpu_param_dict,
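
The leader fix makes the first block of each slice the onload leader and the last block the offload leader (previously the first block led offloading and there was no onload leader). A tiny sketch of the slicing arithmetic with stand-in blocks:

blocks = list(range(5))  # stand-ins for the submodules of a ModuleList
num_blocks_per_group = 2
for i in range(0, len(blocks), num_blocks_per_group):
    current_modules = blocks[i : i + num_blocks_per_group]
    # onload at the group's first block, offload after its last block
    print(current_modules[0], current_modules[-1])
# prints: 0 1 / 2 3 / 4 4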
@@ -223,7 +310,9 @@ def _apply_group_offloading_block_level(
             matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
         )
         should_offload = force_offload or i > 0
-        _apply_group_offloading(group, should_offload, next_group)
+
+        for group_module in group.modules:
+            _apply_group_offloading_hook(group_module, group, should_offload, next_group)
 
     parameters = []
     for name, parameter in module.named_parameters(recurse=False):
@@ -241,50 +330,121 @@ def _apply_group_offloading_block_level(
         offload_device=offload_device,
         onload_device=onload_device,
         offload_leader=module,
-        onload_leader=None,
+        onload_leader=module,
+        parameters=parameters,
+        buffers=buffers,
+        non_blocking=False,
+        stream=None,
+        cpu_param_dict=None,
+        onload_self=True,
+    )
+    _apply_group_offloading_hook(module, unmatched_group, force_offload, matched_module_groups[0])
+
+
+def _apply_group_offloading_leaf_level(
+    module: torch.nn.Module,
+    offload_device: torch.device,
+    onload_device: torch.device,
+    force_offload: bool,
+    non_blocking: bool,
+    stream: Optional[torch.cuda.Stream] = None,
+) -> None:
+    r"""
+    This function applies offloading to groups of leaf modules in a torch.nn.Module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module to which group offloading is applied.
+        offload_device (`torch.device`):
+            The device to which the group of modules are offloaded. This should typically be the CPU.
+        onload_device (`torch.device`):
+            The device to which the group of modules are onloaded.
+        force_offload (`bool`):
+            If True, all module groups are offloaded to the offload_device. If False, only layers that match
+            `offload_group_patterns` are offloaded to the offload_device.
+        non_blocking (`bool`):
+            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
+            and data transfer.
+        stream (`torch.cuda.Stream`, *optional*):
+            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
+            for overlapping computation and data transfer.
+    """
+
+    cpu_param_dict = None
+    if stream is not None:
+        for param in module.parameters():
+            param.data = param.data.cpu().pin_memory()
+        cpu_param_dict = {param: param.data for param in module.parameters()}
+
+    for submodule in module.modules():
+        if len(list(submodule.children())) != 0:
+            continue
+        group = ModuleGroup(
+            modules=[submodule],
+            offload_device=offload_device,
+            onload_device=onload_device,
+            offload_leader=submodule,
+            onload_leader=submodule,
+            non_blocking=non_blocking,
+            stream=stream,
+            cpu_param_dict=cpu_param_dict,
+            onload_self=True,
+        )
+        _apply_group_offloading_hook(submodule, group, True, None)
+
+    parameters = []
+    buffers = []
+
+    def gather_non_module_parameters_and_buffers(m: torch.nn.Module):
+        if len(list(m.children())) == 0:
+            return
+        for parameter in m.parameters(recurse=False):
+            parameters.append(parameter)
+        for buffer in m.buffers(recurse=False):
+            buffers.append(buffer)
+        for submodule in m.children():
+            gather_non_module_parameters_and_buffers(submodule)
+
+    gather_non_module_parameters_and_buffers(module)
+    unmatched_group = ModuleGroup(
+        modules=[],
+        offload_device=offload_device,
+        onload_device=onload_device,
+        offload_leader=module,
+        onload_leader=module,
         parameters=parameters,
         buffers=buffers,
         non_blocking=False,
         stream=None,
         cpu_param_dict=cpu_param_dict,
         onload_self=True,
     )
-    _apply_group_offloading(unmatched_group, force_offload, matched_module_groups[0])
-
-
-# def _apply_group_offloading_leaf_level(
-#     module: torch.nn.Module,
-#     offload_device: torch.device,
-#     onload_device: torch.device,
-#     force_offload: bool,
-#     non_blocking: bool,
-#     stream: Optional[torch.cuda.Stream] = None,
-# ) -> None:
-#     r"""
-#     This function applies offloading to groups of leaf modules in a torch.nn.Module.
-#
-#     Args:
-#         module (`torch.nn.Module`):
-#             The module to which group offloading is applied.
-#         offload_device (`torch.device`):
-#             The device to which the group of modules are offloaded. This should typically be the CPU.
-#         onload_device (`torch.device`):
-#             The device to which the group of modules are onloaded.
-#         force_offload (`bool`):
-#             If True, all module groups are offloaded to the offload_device. If False, only layers that match
-#             `offload_group_patterns` are offloaded to the offload_device.
-#         non_blocking (`bool`):
-#             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-#             and data transfer.
-#         stream (`torch.cuda.Stream`, *optional*):
-#             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-#             for overlapping computation and data transfer.
-#     """
-
-#     cpu_param_dict = None
-#     if stream is not None:
-#         for param in module.parameters():
-#             param.data = param.data.cpu().pin_memory()
-#         cpu_param_dict = {param: param.data for param in module.parameters()}
-
-
-def _apply_group_offloading(
+
+    if stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, force_offload, None)
+    else:
+        _apply_lazy_group_offloading_hook(module, unmatched_group, force_offload, None)
+
+
+def _apply_group_offloading_hook(
+    module: torch.nn.Module,
+    group: ModuleGroup,
+    offload_on_init: bool,
+    next_group: Optional[ModuleGroup] = None,
+) -> None:
+    hook = GroupOffloadingHook(group, offload_on_init, next_group)
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry.register_hook(hook, _GROUP_OFFLOADING)
+
+
+def _apply_lazy_group_offloading_hook(
+    module: torch.nn.Module,
     group: ModuleGroup,
     offload_on_init: bool,
     next_group: Optional[ModuleGroup] = None,
 ) -> None:
-    for module in group.modules:
-        hook = GroupOffloadingHook(group, offload_on_init, next_group)
-        registry = HookRegistry.check_if_exists_or_initialize(module)
-        registry.register_hook(hook, "group_offloading")
+    hook = GroupOffloadingHook(group, offload_on_init, next_group)
+    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry.register_hook(hook, _GROUP_OFFLOADING)
+    registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
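
Both new code paths build on the same CUDA mechanics: parameters are moved to pinned CPU memory (param.data.cpu().pin_memory()) so that onloads issued on a side stream can overlap with compute on the default stream. A standalone sketch of that overlap, independent of this file's helpers:

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    # Pinned (page-locked) CPU memory is what makes non_blocking=True a truly
    # asynchronous host-to-device copy, mirroring cpu_param_dict above.
    weight = torch.randn(1024, 1024).pin_memory()
    with torch.cuda.stream(side_stream):
        weight_gpu = weight.to("cuda", non_blocking=True)  # prefetch
    # The consumer stream must wait for the prefetch before using the tensor.
    torch.cuda.current_stream().wait_stream(side_stream)
    print(weight_gpu.device)  # cuda:0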
