@@ -60,6 +60,8 @@ class GroupOffloadingConfig:
6060 offload_to_disk_path : Optional [str ] = None
6161 stream : Optional [Union [torch .cuda .Stream , torch .Stream ]] = None
6262 block_modules : Optional [List [str ]] = None
63+ exclude_kwargs : Optional [List [str ]] = None
64+ module_prefix : Optional [str ] = ""
6365 pin_groups : Optional [Union [str , Callable ]] = None
6466
6567
@@ -156,27 +158,27 @@ def _pinned_memory_tensors(self):
156158 finally :
157159 pinned_dict = None
158160
159- def _transfer_tensor_to_device (self , tensor , source_tensor ):
161+ def _transfer_tensor_to_device (self , tensor , source_tensor , default_stream = None ):
160162 tensor .data = source_tensor .to (self .onload_device , non_blocking = self .non_blocking )
161163 if self .record_stream :
162- tensor .data .record_stream (self . _torch_accelerator_module . current_stream () )
164+ tensor .data .record_stream (default_stream )
163165
164- def _process_tensors_from_modules (self , pinned_memory = None ):
166+ def _process_tensors_from_modules (self , pinned_memory = None , default_stream = None ):
165167 for group_module in self .modules :
166168 for param in group_module .parameters ():
167169 source = pinned_memory [param ] if pinned_memory else param .data
168- self ._transfer_tensor_to_device (param , source )
170+ self ._transfer_tensor_to_device (param , source , default_stream )
169171 for buffer in group_module .buffers ():
170172 source = pinned_memory [buffer ] if pinned_memory else buffer .data
171- self ._transfer_tensor_to_device (buffer , source )
173+ self ._transfer_tensor_to_device (buffer , source , default_stream )
172174
173175 for param in self .parameters :
174176 source = pinned_memory [param ] if pinned_memory else param .data
175- self ._transfer_tensor_to_device (param , source )
177+ self ._transfer_tensor_to_device (param , source , default_stream )
176178
177179 for buffer in self .buffers :
178180 source = pinned_memory [buffer ] if pinned_memory else buffer .data
179- self ._transfer_tensor_to_device (buffer , source )
181+ self ._transfer_tensor_to_device (buffer , source , default_stream )
180182
181183 def _onload_from_disk (self ):
182184 if self .stream is not None :
@@ -211,10 +213,11 @@ def _onload_from_memory(self):
211213 self .stream .synchronize ()
212214
213215 context = nullcontext () if self .stream is None else self ._torch_accelerator_module .stream (self .stream )
216+ default_stream = self ._torch_accelerator_module .current_stream () if self .stream is not None else None
214217 with context :
215218 if self .stream is not None :
216219 with self ._pinned_memory_tensors () as pinned_memory :
217- self ._process_tensors_from_modules (pinned_memory )
220+ self ._process_tensors_from_modules (pinned_memory , default_stream = default_stream )
218221 else :
219222 self ._process_tensors_from_modules (None )
220223
@@ -308,13 +311,16 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
308311 self .next_group .onload_ ()
309312
310313 should_synchronize = (
311- not self .group .onload_self and self .group .stream is not None and not should_onload_next_group
314+ not self .group .onload_self
315+ and self .group .stream is not None
316+ and not should_onload_next_group
317+ and not self .group .record_stream
312318 )
313319 if should_synchronize :
314320 self .group .stream .synchronize ()
315321
316322 args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
317- kwargs = send_to_device ( kwargs , self .group . onload_device , non_blocking = self . group . non_blocking )
323+ kwargs = self ._send_kwargs_to_device ( kwargs )
318324 return args , kwargs
319325
320326 # If the current module is the onload_leader of the group, we onload the group if it is supposed
@@ -329,7 +335,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
329335 self .next_group .onload_ ()
330336
331337 should_synchronize = (
332- not self .group .onload_self and self .group .stream is not None and not should_onload_next_group
338+ not self .group .onload_self
339+ and self .group .stream is not None
340+ and not should_onload_next_group
341+ and not self .group .record_stream
333342 )
334343 if should_synchronize :
335344 # If this group didn't onload itself, it means it was asynchronously onloaded by the
@@ -341,7 +350,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
341350 self .group .stream .synchronize ()
342351
343352 args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
344- kwargs = send_to_device ( kwargs , self .group . onload_device , non_blocking = self . group . non_blocking )
353+ kwargs = self ._send_kwargs_to_device ( kwargs )
345354 return args , kwargs
346355
347356 def post_forward (self , module : torch .nn .Module , output ):
@@ -360,10 +369,19 @@ def _is_group_on_device(self) -> bool:
360369 tensors .extend (self .group .parameters )
361370 tensors .extend (self .group .buffers )
362371
363- if len (tensors ) == 0 :
364- return True
372+ return len (tensors ) > 0 and all (t .device == self .group .onload_device for t in tensors )
365373
366- return all (t .device == self .group .onload_device for t in tensors )
def _send_kwargs_to_device(self, kwargs):
    """Move forward-pass kwargs to the group's onload device.

    Keys listed in `self.config.exclude_kwargs` are left untouched in the
    original `kwargs` dict (preserving object identity, e.g. for mutable cache
    state); everything else is dispatched through `send_to_device` with the
    group's `non_blocking` setting. Returns the (possibly mutated) dict.
    """
    excluded = self.config.exclude_kwargs or []
    if not excluded:
        return send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
    movable = {key: value for key, value in kwargs.items() if key not in excluded}
    kwargs.update(
        send_to_device(movable, self.group.onload_device, non_blocking=self.group.non_blocking)
    )
    return kwargs
367385
368386
369387class LazyPrefetchGroupOffloadingHook (ModelHook ):
@@ -524,6 +542,17 @@ def pre_forward(self, module, *args, **kwargs):
524542 return args , kwargs
525543
526544
545+ def _normalize_pin_groups (pin_groups : Optional [Union [str , Callable ]]) -> Optional [Union [str , Callable ]]:
546+ if isinstance (pin_groups , str ):
547+ normalized_pin_groups = pin_groups .lower ()
548+ if normalized_pin_groups not in {"first_last" , "all" }:
549+ raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
550+ return normalized_pin_groups
551+ if pin_groups is not None and not callable (pin_groups ):
552+ raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
553+ return pin_groups
554+
555+
527556def apply_group_offloading (
528557 module : torch .nn .Module ,
529558 onload_device : Union [str , torch .device ],
@@ -536,6 +565,7 @@ def apply_group_offloading(
536565 low_cpu_mem_usage : bool = False ,
537566 offload_to_disk_path : Optional [str ] = None ,
538567 block_modules : Optional [List [str ]] = None ,
568+ exclude_kwargs : Optional [List [str ]] = None ,
539569 pin_groups : Optional [Union [str , Callable ]] = None ,
540570) -> None :
541571 r"""
@@ -597,6 +627,10 @@ def apply_group_offloading(
597627 block_modules (`List[str]`, *optional*):
598628 List of module names that should be treated as blocks for offloading. If provided, only these modules
599629 will be considered for block-level offloading. If not provided, the default block detection logic will be used.
630+ exclude_kwargs (`List[str]`, *optional*):
631+ List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state like
632+ caching lists that need to maintain their object identity across forward passes. If not provided, will be
633+ inferred from the module's `_skip_keys` attribute if it exists.
600634 pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
601635 Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first
602636 and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
@@ -640,17 +674,14 @@ def apply_group_offloading(
640674 if offload_type == GroupOffloadingType .BLOCK_LEVEL and num_blocks_per_group is None :
641675 raise ValueError ("`num_blocks_per_group` must be provided when using `offload_type='block_level'." )
642676
643- normalized_pin_groups = pin_groups
644- if isinstance (pin_groups , str ):
645- normalized_pin_groups = pin_groups .lower ()
646- if normalized_pin_groups not in {"first_last" , "all" }:
647- raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
648- elif pin_groups is not None and not callable (pin_groups ):
649- raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
677+ pin_groups = _normalize_pin_groups (pin_groups )
678+ _raise_error_if_accelerate_model_or_sequential_hook_present (module )
650679
651- pin_groups = normalized_pin_groups
680+ if block_modules is None :
681+ block_modules = getattr (module , "_group_offload_block_modules" , None )
652682
653- _raise_error_if_accelerate_model_or_sequential_hook_present (module )
683+ if exclude_kwargs is None :
684+ exclude_kwargs = getattr (module , "_skip_keys" , None )
654685
655686 config = GroupOffloadingConfig (
656687 onload_device = onload_device ,
@@ -663,6 +694,8 @@ def apply_group_offloading(
663694 low_cpu_mem_usage = low_cpu_mem_usage ,
664695 offload_to_disk_path = offload_to_disk_path ,
665696 block_modules = block_modules ,
697+ exclude_kwargs = exclude_kwargs ,
698+ module_prefix = "" ,
666699 pin_groups = pin_groups ,
667700 )
668701 _apply_group_offloading (module , config )
@@ -701,7 +734,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
701734
702735 for name , submodule in module .named_children ():
703736 # Check if this is an explicitly defined block module
704- if name in block_modules :
737+ if block_modules and name in block_modules :
705738 # Apply block offloading to the specified submodule
706739 _apply_block_offloading_to_submodule (
707740 submodule , name , config , modules_with_group_offloading , matched_module_groups
@@ -713,7 +746,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
713746 if len (current_modules ) == 0 :
714747 continue
715748
716- group_id = f"{ name } _{ i } _{ i + len (current_modules ) - 1 } "
749+ group_id = f"{ config . module_prefix } { name } _{ i } _{ i + len (current_modules ) - 1 } "
717750 group = ModuleGroup (
718751 modules = current_modules ,
719752 offload_device = config .offload_device ,
@@ -766,7 +799,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
766799 stream = None ,
767800 record_stream = False ,
768801 onload_self = True ,
769- group_id = f"{ module .__class__ .__name__ } _unmatched_group" ,
802+ group_id = f"{ config . module_prefix } { module .__class__ .__name__ } _unmatched_group" ,
770803 )
771804 if config .stream is None :
772805 _apply_group_offloading_hook (module , unmatched_group , config = config )
@@ -797,7 +830,7 @@ def _apply_block_offloading_to_submodule(
797830 if len (current_modules ) == 0 :
798831 continue
799832
800- group_id = f"{ name } _{ i } _{ i + len (current_modules ) - 1 } "
833+ group_id = f"{ config . module_prefix } { name } _{ i } _{ i + len (current_modules ) - 1 } "
801834 group = ModuleGroup (
802835 modules = current_modules ,
803836 offload_device = config .offload_device ,
@@ -829,7 +862,7 @@ def _apply_block_offloading_to_submodule(
829862 record_stream = config .record_stream ,
830863 low_cpu_mem_usage = config .low_cpu_mem_usage ,
831864 onload_self = True ,
832- group_id = name ,
865+ group_id = f" { config . module_prefix } { name } " ,
833866 )
834867 matched_module_groups .append (group )
835868 modules_with_group_offloading .add (name )
@@ -859,7 +892,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
859892 record_stream = config .record_stream ,
860893 low_cpu_mem_usage = config .low_cpu_mem_usage ,
861894 onload_self = True ,
862- group_id = name ,
895+ group_id = f" { config . module_prefix } { name } " ,
863896 )
864897 _apply_group_offloading_hook (submodule , group , config = config )
865898 modules_with_group_offloading .add (name )
@@ -906,7 +939,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
906939 record_stream = config .record_stream ,
907940 low_cpu_mem_usage = config .low_cpu_mem_usage ,
908941 onload_self = True ,
909- group_id = name ,
942+ group_id = f" { config . module_prefix } { name } " ,
910943 )
911944 _apply_group_offloading_hook (parent_module , group , config = config )
912945
@@ -962,7 +995,7 @@ def _apply_lazy_group_offloading_hook(
962995 hook = GroupOffloadingHook (group , config = config )
963996 registry .register_hook (hook , _GROUP_OFFLOADING )
964997
965- lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook (pin_groups = config .pin_groups )
998+ lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook (pin_groups = config .pin_groups )
966999 registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
9671000
9681001
0 commit comments