@@ -281,11 +281,9 @@ class GroupOffloadingHook(ModelHook):
281281
282282 _is_stateful = False
283283
284- def __init__ (
285- self , group : ModuleGroup , next_group : Optional [ModuleGroup ] = None , * , config : GroupOffloadingConfig
286- ) -> None :
284+ def __init__ (self , group : ModuleGroup , * , config : GroupOffloadingConfig ) -> None :
287285 self .group = group
288- self .next_group = next_group
286+ self .next_group : Optional [ ModuleGroup ] = None
289287 self .config = config
290288
291289 def initialize_hook (self , module : torch .nn .Module ) -> torch .nn .Module :
@@ -609,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
609607 # Apply group offloading hooks to the module groups
610608 for i , group in enumerate (matched_module_groups ):
611609 for group_module in group .modules :
612- _apply_group_offloading_hook (group_module , group , None , config = config )
610+ _apply_group_offloading_hook (group_module , group , config = config )
613611
614612 # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
615613 # when the forward pass of this module is called. This is because the top-level module is not
@@ -638,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
638636 group_id = f"{ module .__class__ .__name__ } _unmatched_group" ,
639637 )
640638 if config .stream is None :
641- _apply_group_offloading_hook (module , unmatched_group , None , config = config )
639+ _apply_group_offloading_hook (module , unmatched_group , config = config )
642640 else :
643- _apply_lazy_group_offloading_hook (module , unmatched_group , None , config = config )
641+ _apply_lazy_group_offloading_hook (module , unmatched_group , config = config )
644642
645643
646644def _apply_group_offloading_leaf_level (module : torch .nn .Module , config : GroupOffloadingConfig ) -> None :
@@ -669,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
669667 onload_self = True ,
670668 group_id = name ,
671669 )
672- _apply_group_offloading_hook (submodule , group , None , config = config )
670+ _apply_group_offloading_hook (submodule , group , config = config )
673671 modules_with_group_offloading .add (name )
674672
675673 # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -716,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
716714 onload_self = True ,
717715 group_id = name ,
718716 )
719- _apply_group_offloading_hook (parent_module , group , None , config = config )
717+ _apply_group_offloading_hook (parent_module , group , config = config )
720718
721719 if config .stream is not None :
722720 # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -738,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
738736 onload_self = True ,
739737 group_id = _GROUP_ID_LAZY_LEAF ,
740738 )
741- _apply_lazy_group_offloading_hook (module , unmatched_group , None , config = config )
739+ _apply_lazy_group_offloading_hook (module , unmatched_group , config = config )
742740
743741
744742def _apply_group_offloading_hook (
745743 module : torch .nn .Module ,
746744 group : ModuleGroup ,
747- next_group : Optional [ModuleGroup ] = None ,
748745 * ,
749746 config : GroupOffloadingConfig ,
750747) -> None :
@@ -753,14 +750,13 @@ def _apply_group_offloading_hook(
753750 # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
754751 # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
755752 if registry .get_hook (_GROUP_OFFLOADING ) is None :
756- hook = GroupOffloadingHook (group , next_group , config = config )
753+ hook = GroupOffloadingHook (group , config = config )
757754 registry .register_hook (hook , _GROUP_OFFLOADING )
758755
759756
760757def _apply_lazy_group_offloading_hook (
761758 module : torch .nn .Module ,
762759 group : ModuleGroup ,
763- next_group : Optional [ModuleGroup ] = None ,
764760 * ,
765761 config : GroupOffloadingConfig ,
766762) -> None :
@@ -769,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
769765 # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
770766 # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
771767 if registry .get_hook (_GROUP_OFFLOADING ) is None :
772- hook = GroupOffloadingHook (group , next_group , config = config )
768+ hook = GroupOffloadingHook (group , config = config )
773769 registry .register_hook (hook , _GROUP_OFFLOADING )
774770
775771 lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook ()
0 commit comments