@@ -60,6 +60,8 @@ class GroupOffloadingConfig:
6060 offload_to_disk_path : Optional [str ] = None
6161 stream : Optional [Union [torch .cuda .Stream , torch .Stream ]] = None
6262 block_modules : Optional [List [str ]] = None
63+ exclude_kwargs : Optional [List [str ]] = None
64+ module_prefix : Optional [str ] = ""
6365 pin_groups : Optional [Union [str , Callable ]] = None
6466
6567
@@ -156,7 +158,7 @@ def _pinned_memory_tensors(self):
156158 finally :
157159 pinned_dict = None
158160
159- def _transfer_tensor_to_device (self , tensor , source_tensor ):
161+ def _transfer_tensor_to_device (self , tensor , source_tensor , default_stream = None ):
160162 tensor .data = source_tensor .to (self .onload_device , non_blocking = self .non_blocking )
161163 if self .record_stream :
162164 tensor .data .record_stream (self ._torch_accelerator_module .current_stream ())
@@ -211,6 +213,7 @@ def _onload_from_memory(self):
211213 self .stream .synchronize ()
212214
213215 context = nullcontext () if self .stream is None else self ._torch_accelerator_module .stream (self .stream )
216+ default_stream = self ._torch_accelerator_module .current_stream () if self .stream is not None else None
214217 with context :
215218 if self .stream is not None :
216219 with self ._pinned_memory_tensors () as pinned_memory :
@@ -308,13 +311,16 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
308311 self .next_group .onload_ ()
309312
310313 should_synchronize = (
311- not self .group .onload_self and self .group .stream is not None and not should_onload_next_group
314+ not self .group .onload_self
315+ and self .group .stream is not None
316+ and not should_onload_next_group
317+ and not self .group .record_stream
312318 )
313319 if should_synchronize :
314320 self .group .stream .synchronize ()
315321
316322 args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
317- kwargs = send_to_device ( kwargs , self .group . onload_device , non_blocking = self . group . non_blocking )
323+ kwargs = self ._send_kwargs_to_device ( kwargs )
318324 return args , kwargs
319325
320326 # If the current module is the onload_leader of the group, we onload the group if it is supposed
@@ -329,7 +335,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
329335 self .next_group .onload_ ()
330336
331337 should_synchronize = (
332- not self .group .onload_self and self .group .stream is not None and not should_onload_next_group
338+ not self .group .onload_self
339+ and self .group .stream is not None
340+ and not should_onload_next_group
341+ and not self .group .record_stream
333342 )
334343 if should_synchronize :
335344 # If this group didn't onload itself, it means it was asynchronously onloaded by the
@@ -341,7 +350,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
341350 self .group .stream .synchronize ()
342351
343352 args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
344- kwargs = send_to_device ( kwargs , self .group . onload_device , non_blocking = self . group . non_blocking )
353+ kwargs = self ._send_kwargs_to_device ( kwargs )
345354 return args , kwargs
346355
347356 def post_forward (self , module : torch .nn .Module , output ):
@@ -352,6 +361,28 @@ def post_forward(self, module: torch.nn.Module, output):
352361 self .group .offload_ ()
353362 return output
354363
364+ def _is_group_on_device (self ) -> bool :
365+ tensors = []
366+ for group_module in self .group .modules :
367+ tensors .extend (list (group_module .parameters ()))
368+ tensors .extend (list (group_module .buffers ()))
369+ tensors .extend (self .group .parameters )
370+ tensors .extend (self .group .buffers )
371+
372+ return len (tensors ) > 0 and all (t .device == self .group .onload_device for t in tensors )
373+
374+ def _send_kwargs_to_device (self , kwargs ):
375+ exclude_kwargs = self .config .exclude_kwargs or []
376+ if exclude_kwargs :
377+ moved_kwargs = send_to_device (
378+ {k : v for k , v in kwargs .items () if k not in exclude_kwargs },
379+ self .group .onload_device ,
380+ non_blocking = self .group .non_blocking ,
381+ )
382+ kwargs .update (moved_kwargs )
383+ return kwargs
384+ return send_to_device (kwargs , self .group .onload_device , non_blocking = self .group .non_blocking )
385+
355386 def _is_group_on_device (self ) -> bool :
356387 tensors = []
357388 for group_module in self .group .modules :
@@ -524,6 +555,17 @@ def pre_forward(self, module, *args, **kwargs):
524555 return args , kwargs
525556
526557
558+ def _normalize_pin_groups (pin_groups : Optional [Union [str , Callable ]]) -> Optional [Union [str , Callable ]]:
559+ if isinstance (pin_groups , str ):
560+ normalized_pin_groups = pin_groups .lower ()
561+ if normalized_pin_groups not in {"first_last" , "all" }:
562+ raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
563+ return normalized_pin_groups
564+ if pin_groups is not None and not callable (pin_groups ):
565+ raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
566+ return pin_groups
567+
568+
527569def apply_group_offloading (
528570 module : torch .nn .Module ,
529571 onload_device : Union [str , torch .device ],
@@ -536,6 +578,7 @@ def apply_group_offloading(
536578 low_cpu_mem_usage : bool = False ,
537579 offload_to_disk_path : Optional [str ] = None ,
538580 block_modules : Optional [List [str ]] = None ,
581+ exclude_kwargs : Optional [List [str ]] = None ,
539582 pin_groups : Optional [Union [str , Callable ]] = None ,
540583) -> None :
541584 r"""
@@ -595,11 +638,15 @@ def apply_group_offloading(
595638 option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
596639 the CPU memory is a bottleneck but may counteract the benefits of using streams.
597640 block_modules (`List[str]`, *optional*):
598- List of module names that should be treated as blocks for offloading. If provided, only these modules
599- will be considered for block-level offloading. If not provided, the default block detection logic will be used.
641+ List of module names that should be treated as blocks for offloading. If provided, only these modules will
642+ be considered for block-level offloading. If not provided, the default block detection logic will be used.
643+ exclude_kwargs (`List[str]`, *optional*):
644+ List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state like
645+ caching lists that need to maintain their object identity across forward passes. If not provided, will be
646+ inferred from the module's `_skip_keys` attribute if it exists.
600647 pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
601- Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first
602- and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
648+ Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first and
649+ last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
603650 receives a module (and optionally the module name and index) and returns `True` to pin that group.
604651
605652 Example:
@@ -640,19 +687,14 @@ def apply_group_offloading(
640687 if offload_type == GroupOffloadingType .BLOCK_LEVEL and num_blocks_per_group is None :
641688 raise ValueError ("`num_blocks_per_group` must be provided when using `offload_type='block_level'." )
642689
643- normalized_pin_groups = pin_groups
644- if isinstance (pin_groups , str ):
645- normalized_pin_groups = pin_groups .lower ()
646- if normalized_pin_groups not in {"first_last" , "all" }:
647- raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
648- elif pin_groups is not None and not callable (pin_groups ):
649- raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
690+ pin_groups = _normalize_pin_groups (pin_groups )
691+ _raise_error_if_accelerate_model_or_sequential_hook_present (module )
650692
651- pin_groups = normalized_pin_groups
693+ if block_modules is None :
694+ block_modules = getattr (module , "_group_offload_block_modules" , None )
652695
653- _raise_error_if_accelerate_model_or_sequential_hook_present (module )
654- registry = HookRegistry .check_if_exists_or_initialize (module )
655- registry ._group_offload_pin_groups = pin_groups
696+ if exclude_kwargs is None :
697+ exclude_kwargs = getattr (module , "_skip_keys" , None )
656698
657699 config = GroupOffloadingConfig (
658700 onload_device = onload_device ,
@@ -665,6 +707,8 @@ def apply_group_offloading(
665707 low_cpu_mem_usage = low_cpu_mem_usage ,
666708 offload_to_disk_path = offload_to_disk_path ,
667709 block_modules = block_modules ,
710+ exclude_kwargs = exclude_kwargs ,
711+ module_prefix = "" ,
668712 pin_groups = pin_groups ,
669713 )
670714 _apply_group_offloading (module , config )
@@ -706,7 +750,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
706750
707751 for name , submodule in module .named_children ():
708752 # Check if this is an explicitly defined block module
709- if name in block_modules :
753+ if block_modules and name in block_modules :
710754 # Apply block offloading to the specified submodule
711755 _apply_block_offloading_to_submodule (
712756 submodule , name , config , modules_with_group_offloading , matched_module_groups
@@ -802,7 +846,7 @@ def _apply_block_offloading_to_submodule(
802846 if len (current_modules ) == 0 :
803847 continue
804848
805- group_id = f"{ name } _{ i } _{ i + len (current_modules ) - 1 } "
849+ group_id = f"{ config . module_prefix } { name } _{ i } _{ i + len (current_modules ) - 1 } "
806850 group = ModuleGroup (
807851 modules = current_modules ,
808852 offload_device = config .offload_device ,
@@ -834,7 +878,7 @@ def _apply_block_offloading_to_submodule(
834878 record_stream = config .record_stream ,
835879 low_cpu_mem_usage = config .low_cpu_mem_usage ,
836880 onload_self = True ,
837- group_id = name ,
881+ group_id = f" { config . module_prefix } { name } " ,
838882 )
839883 matched_module_groups .append (group )
840884 modules_with_group_offloading .add (name )
@@ -864,7 +908,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
864908 record_stream = config .record_stream ,
865909 low_cpu_mem_usage = config .low_cpu_mem_usage ,
866910 onload_self = True ,
867- group_id = name ,
911+ group_id = f" { config . module_prefix } { name } " ,
868912 )
869913 _apply_group_offloading_hook (submodule , group , config = config )
870914 modules_with_group_offloading .add (name )
@@ -911,7 +955,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
911955 record_stream = config .record_stream ,
912956 low_cpu_mem_usage = config .low_cpu_mem_usage ,
913957 onload_self = True ,
914- group_id = name ,
958+ group_id = f" { config . module_prefix } { name } " ,
915959 )
916960 _apply_group_offloading_hook (parent_module , group , config = config )
917961
@@ -966,8 +1010,8 @@ def _apply_lazy_group_offloading_hook(
9661010 if registry .get_hook (_GROUP_OFFLOADING ) is None :
9671011 hook = GroupOffloadingHook (group , config = config )
9681012 registry .register_hook (hook , _GROUP_OFFLOADING )
969-
970- lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook (pin_groups = config .pin_groups )
1013+
1014+ lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook (pin_groups = config .pin_groups )
9711015 registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
9721016
9731017
0 commit comments