     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
     torch.nn.Linear,
-    torch.nn.LayerNorm, torch.nn.GroupNorm,
+    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
+    # because of double invocation of the same norm layer in CogVideoXLayerNorm
 )
 # fmt: on
 
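The LayerNorm/GroupNorm exclusion above comes from layers that are invoked more than once per forward pass. A minimal, hypothetical sketch of the issue (not the diffusers implementation): a per-module hook fires once per invocation, so a wrapper that reuses the same norm instance twice, as the TODO says `CogVideoXLayerNorm` does, triggers the hook twice within a single outer forward.

```python
import torch


class DoubleNorm(torch.nn.Module):
    """Hypothetical wrapper that reuses one norm instance twice per forward."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return self.norm(x), self.norm(y)


calls = []
model = DoubleNorm(8)
# Any per-module hook (for example one that offloads weights after the call) fires twice here.
model.norm.register_forward_hook(lambda mod, args, out: calls.append(1))
model(torch.randn(2, 8), torch.randn(2, 8))
print(len(calls))  # 2 hook invocations for a single outer forward
```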
@@ -120,15 +121,13 @@ class GroupOffloadingHook(ModelHook):
     def __init__(
         self,
         group: ModuleGroup,
-        offload_on_init: bool = True,
         next_group: Optional[ModuleGroup] = None,
     ) -> None:
         self.group = group
-        self.offload_on_init = offload_on_init
         self.next_group = next_group
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
-        if self.offload_on_init and self.group.offload_leader == module:
+        if self.group.offload_leader == module:
             self.group.offload_()
         return module
 
@@ -262,14 +261,78 @@ def pre_forward(self, module, *args, **kwargs):
 
 def apply_group_offloading(
     module: torch.nn.Module,
+    onload_device: torch.device,
+    offload_device: torch.device = torch.device("cpu"),
     offload_type: str = "block_level",
     num_blocks_per_group: Optional[int] = None,
-    offload_device: torch.device = torch.device("cpu"),
-    onload_device: torch.device = torch.device("cuda"),
-    force_offload: bool = True,
     non_blocking: bool = False,
     use_stream: bool = False,
 ) -> None:
+    r"""
+    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
+    where it is beneficial, it helps to first explain how the other supported offloading methods work.
+
+    Typically, offloading is done at one of two levels:
+    - Module-level: In Diffusers, this can be enabled using the `DiffusionPipeline::enable_model_cpu_offload()` method.
+    It works by offloading each component of a pipeline to the CPU for storage, and onloading it to the accelerator
+    device when needed for computation. This method is more memory-efficient than keeping all components on the
+    accelerator, but the memory requirements are still quite high. To complete the forward pass, one needs memory
+    equivalent to the size of the model in its runtime dtype plus the size of the largest intermediate activation
+    tensors.
+    - Leaf-level: In Diffusers, this can be enabled using the `DiffusionPipeline::enable_sequential_cpu_offload()`
+    method. It works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage,
+    and onloading only those leaves to the accelerator device for computation. This uses the lowest amount of
+    accelerator memory, but can be slower due to the large number of device synchronizations.
+
+    Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers
+    (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method is more memory-efficient than module-level
+    offloading. It is also faster than leaf-level offloading, as the number of device synchronizations is reduced.
+
+    Another supported feature (for CUDA devices that support asynchronous data transfer streams) is the ability to
+    overlap data transfer and computation to reduce the overall execution time. This is enabled through layer
+    prefetching with streams: the layer that is to be executed next starts onloading to the accelerator device while
+    the current layer is being executed, which slightly increases the memory requirements. Note that this
+    implementation also supports leaf-level offloading, which can be made much faster when streams are used.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module to which group offloading is applied.
+        onload_device (`torch.device`):
+            The device to which the group of modules are onloaded.
+        offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+            The device to which the group of modules are offloaded. This should typically be the CPU.
+        offload_type (`str`, defaults to "block_level"):
+            The type of offloading to be applied. Can be one of "block_level" or "leaf_level".
+        num_blocks_per_group (`int`, *optional*):
+            The number of blocks per group. This is required when using offload_type="block_level".
+        non_blocking (`bool`, defaults to `False`):
+            If True, offloading and onloading is done with non-blocking data transfer.
+        use_stream (`bool`, defaults to `False`):
+            If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
+            overlapping computation and data transfer.
+
+    Example:
+        ```python
+        >>> import torch
+        >>> from diffusers import CogVideoXTransformer3DModel
+        >>> from diffusers.hooks import apply_group_offloading

+        >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+        ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+        ... )

+        >>> apply_group_offloading(
+        ...     transformer,
+        ...     onload_device=torch.device("cuda"),
+        ...     offload_device=torch.device("cpu"),
+        ...     offload_type="block_level",
+        ...     num_blocks_per_group=2,
+        ...     use_stream=True,
+        ... )
+        ```
+    """
+
     stream = None
     if use_stream:
         if torch.cuda.is_available():
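For context on the two pipeline-level methods contrasted in the docstring above, they are used roughly as follows. This is an illustrative sketch only; the pipeline class and checkpoint are assumptions and not part of this diff.

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Module-level offloading: whole pipeline components move between the CPU and
# the accelerator as each one is needed.
pipe.enable_model_cpu_offload()

# Leaf-level offloading: individual leaf parameters move instead; lowest memory
# use, but many more device synchronizations.
# pipe.enable_sequential_cpu_offload()
```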
@@ -279,15 +342,13 @@ def apply_group_offloading(
 
     if offload_type == "block_level":
         if num_blocks_per_group is None:
-            raise ValueError("num_blocks_per_group must be provided when using offload_group_patterns='block_level'.")
+            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking, stream=stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, force_offload, non_blocking, stream=stream
-        )
+        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 
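As a complement to the block-level example in the docstring above, a leaf-level call to the same function would look roughly like this. It is a usage sketch: `transformer` is the module loaded in that docstring example, and `num_blocks_per_group` is not needed for this mode.

```python
import torch

from diffusers.hooks import apply_group_offloading

apply_group_offloading(
    transformer,  # the CogVideoX transformer from the docstring example above
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```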
@@ -297,7 +358,6 @@ def _apply_group_offloading_block_level(
     num_blocks_per_group: int,
     offload_device: torch.device,
     onload_device: torch.device,
-    force_offload: bool,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
 ) -> None:
@@ -312,9 +372,6 @@ def _apply_group_offloading_block_level(
             The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
-        force_offload (`bool`):
-            If True, all module groups are offloaded to the offload_device. If False, only layers that match
-            `offload_group_patterns` are offloaded to the offload_device.
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
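The hunk below wires up `next_group` only when a stream is available, which is what enables the overlap described in the docstring. A rough standalone sketch of the idea (assumptions only, not the hook machinery in this file): while the current group computes on the default stream, the next group's weights are copied on a separate transfer stream, and the compute stream waits on that transfer before using them.

```python
import torch

if torch.cuda.is_available():
    transfer_stream = torch.cuda.Stream()

    current_group = torch.nn.Linear(1024, 1024).cuda()
    next_group = torch.nn.Linear(1024, 1024)  # still on the CPU

    x = torch.randn(64, 1024, device="cuda")

    # Start onloading the next group's weights on the side stream
    # (true overlap additionally requires pinned CPU memory).
    with torch.cuda.stream(transfer_stream):
        next_group.to("cuda", non_blocking=True)

    y = current_group(x)  # runs on the default stream, concurrently with the copy

    # Make sure the copy has finished before the next group is used.
    torch.cuda.current_stream().wait_stream(transfer_stream)
    y = next_group(y)
```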
@@ -362,10 +419,9 @@ def _apply_group_offloading_block_level(
         next_group = (
             matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
         )
-        should_offload = force_offload or i > 0
 
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, should_offload, next_group)
+            _apply_group_offloading_hook(group_module, group, next_group)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -392,14 +448,13 @@ def _apply_group_offloading_block_level(
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, force_offload, next_group)
+    _apply_group_offloading_hook(module, unmatched_group, next_group)
 
 
 def _apply_group_offloading_leaf_level(
     module: torch.nn.Module,
     offload_device: torch.device,
     onload_device: torch.device,
-    force_offload: bool,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
 ) -> None:
@@ -416,9 +471,6 @@ def _apply_group_offloading_leaf_level(
             The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
-        force_offload (`bool`):
-            If True, all module groups are offloaded to the offload_device. If False, only layers that match
-            `offload_group_patterns` are offloaded to the offload_device.
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
@@ -450,7 +502,7 @@ def _apply_group_offloading_leaf_level(
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(submodule, group, True, None)
+        _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -495,7 +547,7 @@ def _apply_group_offloading_leaf_level(
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(parent_module, group, True, None)
+        _apply_group_offloading_hook(parent_module, group, None)
 
     # This is a dummy group that will handle lazy prefetching from the top-level module to the first leaf module
     unmatched_group = ModuleGroup(
@@ -516,38 +568,36 @@ def _apply_group_offloading_leaf_level(
     # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
     # execution order and apply prefetching in the correct order.
     if stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, force_offload, None)
+        _apply_group_offloading_hook(module, unmatched_group, None)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, force_offload, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
 
 
 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    offload_on_init: bool,
     next_group: Optional[ModuleGroup] = None,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, offload_on_init, next_group)
+        hook = GroupOffloadingHook(group, next_group)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    offload_on_init: bool,
     next_group: Optional[ModuleGroup] = None,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, offload_on_init, next_group)
+        hook = GroupOffloadingHook(group, next_group)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
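The lazy prefetching hook mentioned in the comment above needs the actual execution order of the leaf modules, which is only known once a forward pass has run. A hypothetical sketch of how such an order can be recorded with forward pre-hooks (not the `LazyPrefetchGroupOffloadingHook` implementation itself):

```python
import torch

execution_order = []


def make_recorder(name):
    def record(module, args):
        execution_order.append(name)

    return record


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
for name, submodule in model.named_modules():
    if len(list(submodule.children())) == 0:  # leaf modules only
        submodule.register_forward_pre_hook(make_recorder(name))

model(torch.randn(1, 4))
print(execution_order)  # ['0', '1', '2'] -- prefetching can now follow this order
```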
@@ -561,14 +611,12 @@ def _gather_parameters_with_no_group_offloading_parent(
     for name, parameter in module.named_parameters():
         has_parent_with_group_offloading = False
         atoms = name.split(".")
-
         while len(atoms) > 0:
             parent_name = ".".join(atoms)
             if parent_name in modules_with_group_offloading:
                 has_parent_with_group_offloading = True
                 break
             atoms.pop()
-
         if not has_parent_with_group_offloading:
             parameters.append((name, parameter))
     return parameters
@@ -581,14 +629,12 @@ def _gather_buffers_with_no_group_offloading_parent(
     for name, buffer in module.named_buffers():
         has_parent_with_group_offloading = False
         atoms = name.split(".")
-
         while len(atoms) > 0:
             parent_name = ".".join(atoms)
             if parent_name in modules_with_group_offloading:
                 has_parent_with_group_offloading = True
                 break
             atoms.pop()
-
         if not has_parent_with_group_offloading:
             buffers.append((name, buffer))
     return buffers
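The two helpers above share the same parent-lookup idea: walk up a dotted parameter or buffer name and check whether any ancestor module already carries a group offloading hook. A tiny standalone sketch of that walk (the names are made up for illustration):

```python
modules_with_group_offloading = {"blocks.0", "blocks.1"}


def has_offloading_parent(qualified_name: str) -> bool:
    # Drop one trailing atom at a time and test each ancestor name.
    atoms = qualified_name.split(".")
    while atoms:
        if ".".join(atoms) in modules_with_group_offloading:
            return True
        atoms.pop()
    return False


print(has_offloading_parent("blocks.0.attn.to_q.weight"))  # True
print(has_offloading_parent("proj_out.weight"))            # False
```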