@@ -376,9 +376,10 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
376376
377377 _is_stateful = False
378378
379- def __init__ (self ):
379+ def __init__ (self , pin_groups : Optional [ Union [ str , Callable ]] = None ):
380380 self .execution_order : List [Tuple [str , torch .nn .Module ]] = []
381381 self ._layer_execution_tracker_module_names = set ()
382+ self .pin_groups = pin_groups
382383
383384 def initialize_hook (self , module ):
384385 def make_execution_order_update_callback (current_name , current_submodule ):
@@ -460,8 +461,7 @@ def post_forward(self, module, output):
460461 group_offloading_hooks [i ].next_group = group_offloading_hooks [i + 1 ].group
461462 group_offloading_hooks [i ].next_group .onload_self = False
462463
463- pin_groups = getattr (base_module_registry , "_group_offload_pin_groups" , None )
464- if pin_groups is not None and num_executed > 0 :
464+ if self .pin_groups is not None and num_executed > 0 :
465465 param_exec_info = []
466466 for idx , ((name , submodule ), hook ) in enumerate (zip (self .execution_order , group_offloading_hooks )):
467467 if hook is None :
@@ -473,22 +473,22 @@ def post_forward(self, module, output):
473473 num_param_modules = len (param_exec_info )
474474 if num_param_modules > 0 :
475475 pinned_indices = set ()
476- if isinstance (pin_groups , str ):
477- if pin_groups == "all" :
476+ if isinstance (self . pin_groups , str ):
477+ if self . pin_groups == "all" :
478478 pinned_indices = set (range (num_param_modules ))
479- elif pin_groups == "first_last" :
479+ elif self . pin_groups == "first_last" :
480480 pinned_indices .add (0 )
481481 pinned_indices .add (num_param_modules - 1 )
482- elif callable (pin_groups ):
482+ elif callable (self . pin_groups ):
483483 for idx , (name , submodule , _ ) in enumerate (param_exec_info ):
484484 should_pin = False
485485 try :
486- should_pin = bool (pin_groups (submodule ))
486+ should_pin = bool (self . pin_groups (submodule ))
487487 except TypeError :
488488 try :
489- should_pin = bool (pin_groups (name , submodule ))
489+ should_pin = bool (self . pin_groups (name , submodule ))
490490 except TypeError :
491- should_pin = bool (pin_groups (name , submodule , idx ))
491+ should_pin = bool (self . pin_groups (name , submodule , idx ))
492492 if should_pin :
493493 pinned_indices .add (idx )
494494
@@ -651,8 +651,6 @@ def apply_group_offloading(
651651 pin_groups = normalized_pin_groups
652652
653653 _raise_error_if_accelerate_model_or_sequential_hook_present (module )
654- registry = HookRegistry .check_if_exists_or_initialize (module )
655- registry ._group_offload_pin_groups = pin_groups
656654
657655 config = GroupOffloadingConfig (
658656 onload_device = onload_device ,
@@ -671,9 +669,6 @@ def apply_group_offloading(
671669
672670
673671def _apply_group_offloading (module : torch .nn .Module , config : GroupOffloadingConfig ) -> None :
674- registry = HookRegistry .check_if_exists_or_initialize (module )
675- registry ._group_offload_pin_groups = config .pin_groups
676-
677672 if config .offload_type == GroupOffloadingType .BLOCK_LEVEL :
678673 _apply_group_offloading_block_level (module , config )
679674 elif config .offload_type == GroupOffloadingType .LEAF_LEVEL :
@@ -966,8 +961,8 @@ def _apply_lazy_group_offloading_hook(
966961 if registry .get_hook (_GROUP_OFFLOADING ) is None :
967962 hook = GroupOffloadingHook (group , config = config )
968963 registry .register_hook (hook , _GROUP_OFFLOADING )
969-
970- lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook ()
964+
965+ lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook (pin_groups = config . pin_groups )
971966 registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
972967
973968
0 commit comments