# Module-level logger; lowercase name is the project convention, hence the lint suppression.
logger = get_logger(__name__)  # pylint: disable=invalid-name
3030
3131
class PinnedGroupManager:
    """
    Manages a sliding window of pinned module groups to limit total CPU memory usage.

    At most ``max_pinned_groups`` groups are kept pinned at any given time; when the
    limit is exceeded, the least-recently-pinned groups are unpinned to make room.
    A ``max_pinned_groups`` of ``None`` disables the window entirely (all methods
    become no-ops).
    """

    def __init__(self, max_pinned_groups: Optional[int] = None):
        # Maximum number of simultaneously pinned groups; None means "unlimited".
        self.max_pinned_groups = max_pinned_groups
        # Ordered oldest -> newest; acts as the eviction queue for the window.
        self.pinned_groups = []

    def register_group(self, group: "ModuleGroup") -> None:
        """Attach this manager to ``group`` so it reports pin/unpin events back here."""
        # Registration is pointless when no limit is enforced, so skip it.
        if self.max_pinned_groups is not None:
            group._pinned_group_manager = self

    def on_group_pinned(self, group: "ModuleGroup") -> None:
        """Called when a group is pinned; evicts the oldest pinned groups past the limit."""
        if self.max_pinned_groups is None:
            return

        # Move (or add) the group to the most-recent end of the window. Removing a
        # previously tracked entry first keeps the ordering consistent with pin
        # recency, and prevents a re-pinned group that sits at the front of the
        # queue from being dropped from tracking while it is still pinned.
        if group in self.pinned_groups:
            self.pinned_groups.remove(group)
        self.pinned_groups.append(group)

        # Evict the oldest groups until we are back within the limit.
        while len(self.pinned_groups) > self.max_pinned_groups:
            oldest_group = self.pinned_groups[0]
            if oldest_group is group:
                # Never evict the group that was just pinned (only reachable when
                # max_pinned_groups == 0); keep it tracked and stop evicting.
                break
            self.pinned_groups.pop(0)
            if oldest_group.pinned_memory:
                # unpin_memory_() notifies on_group_unpinned(), which is then a
                # no-op because the group was already removed from the list above.
                oldest_group.unpin_memory_()

    def on_group_unpinned(self, group: "ModuleGroup") -> None:
        """Called when a group is manually unpinned; drops it from the window."""
        if self.max_pinned_groups is None:
            return

        # Remove the group from our tracking list, if present.
        if group in self.pinned_groups:
            self.pinned_groups.remove(group)
72+
# fmt: off
# Registry keys under which the group-offloading and layer-execution-tracker
# hooks are attached to modules.
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -58,6 +99,7 @@ def __init__(
5899 stream : Optional [torch .cuda .Stream ] = None ,
59100 cpu_param_dict : Optional [Dict [torch .nn .Parameter , torch .Tensor ]] = None ,
60101 onload_self : bool = True ,
102+ pinned_group_manager : Optional ["PinnedGroupManager" ] = None ,
61103 ) -> None :
62104 self .modules = modules
63105 self .offload_device = offload_device
@@ -72,6 +114,8 @@ def __init__(
72114 self .onload_self = onload_self
73115 # Track if we've pinned our group's memory
74116 self .pinned_memory = False
117+ # Reference to the pinned group manager for sliding window functionality
118+ self ._pinned_group_manager = pinned_group_manager
75119
76120 if self .stream is not None and self .cpu_param_dict is None :
77121 raise ValueError ("cpu_param_dict must be provided when using stream for data transfer." )
@@ -94,6 +138,10 @@ def pin_memory_(self):
94138 self .cpu_param_dict [param ] = pinned_data
95139 param .data = pinned_data
96140 self .pinned_memory = True
141+
142+ # Notify the manager that this group has been pinned
143+ if hasattr (self , "_pinned_group_manager" ) and self ._pinned_group_manager is not None :
144+ self ._pinned_group_manager .on_group_pinned (self )
97145
98146 def unpin_memory_ (self ):
99147 r"""Unpin the memory of this group's parameters to free up CPU RAM."""
@@ -112,6 +160,10 @@ def unpin_memory_(self):
112160 # Clear the CPU param dict
113161 self .cpu_param_dict = {}
114162 self .pinned_memory = False
163+
164+ # Notify the manager that this group has been unpinned
165+ if hasattr (self , "_pinned_group_manager" ) and self ._pinned_group_manager is not None :
166+ self ._pinned_group_manager .on_group_unpinned (self )
115167
116168 def onload_ (self ):
117169 r"""Onloads the group of modules to the onload_device."""
@@ -321,6 +373,7 @@ def apply_group_offloading(
321373 non_blocking : bool = False ,
322374 use_stream : bool = False ,
323375 unpin_after_use : bool = False ,
376+ max_pinned_groups : Optional [int ] = None ,
324377) -> None :
325378 r"""
326379 Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -371,6 +424,10 @@ def apply_group_offloading(
371424 If True, pinned memory is released after a module group has been offloaded to CPU. This reduces CPU memory
372425 usage at the cost of potentially slower data transfer when the group is loaded again. Useful for large models
373426 with limited CPU RAM.
427+ max_pinned_groups (`int`, *optional*):
428+ If set, limits the number of module groups that can have pinned memory at any given time. When this limit
429+ is reached, the oldest pinned groups are unpinned to make room for new ones. This implements a sliding window
430+ approach to manage CPU memory usage while maintaining good performance for the most recently used groups.
374431
375432 Example:
376433 ```python
@@ -389,6 +446,7 @@ def apply_group_offloading(
389446 ... num_blocks_per_group=2,
390447 ... use_stream=True,
391448 ... unpin_after_use=False, # Set to True to reduce CPU memory usage
449+ ... max_pinned_groups=5, # Limit to 5 pinned groups at a time
392450 ... )
393451 ```
394452 """
@@ -401,16 +459,26 @@ def apply_group_offloading(
401459 raise ValueError ("Using streams for data transfer requires a CUDA device." )
402460
403461 _raise_error_if_accelerate_model_or_sequential_hook_present (module )
462+
463+ # Create a pinned group manager if max_pinned_groups is set
464+ pinned_group_manager = None
465+ if max_pinned_groups is not None and stream is not None :
466+ pinned_group_manager = PinnedGroupManager (max_pinned_groups )
467+ logger .info (f"Using sliding window approach with maximum of { max_pinned_groups } pinned groups" )
404468
405469 if offload_type == "block_level" :
406470 if num_blocks_per_group is None :
407471 raise ValueError ("num_blocks_per_group must be provided when using offload_type='block_level'." )
408472
409473 _apply_group_offloading_block_level (
410- module , num_blocks_per_group , offload_device , onload_device , non_blocking , stream , unpin_after_use
474+ module , num_blocks_per_group , offload_device , onload_device , non_blocking ,
475+ stream , unpin_after_use , pinned_group_manager
411476 )
412477 elif offload_type == "leaf_level" :
413- _apply_group_offloading_leaf_level (module , offload_device , onload_device , non_blocking , stream , unpin_after_use )
478+ _apply_group_offloading_leaf_level (
479+ module , offload_device , onload_device , non_blocking ,
480+ stream , unpin_after_use , pinned_group_manager
481+ )
414482 else :
415483 raise ValueError (f"Unsupported offload_type: { offload_type } " )
416484
@@ -423,6 +491,7 @@ def _apply_group_offloading_block_level(
423491 non_blocking : bool ,
424492 stream : Optional [torch .cuda .Stream ] = None ,
425493 unpin_after_use : bool = False ,
494+ pinned_group_manager : Optional [PinnedGroupManager ] = None ,
426495) -> None :
427496 r"""
428497 This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -469,6 +538,7 @@ def _apply_group_offloading_block_level(
469538 stream = stream ,
470539 cpu_param_dict = cpu_param_dict ,
471540 onload_self = stream is None ,
541+ pinned_group_manager = pinned_group_manager ,
472542 )
473543 matched_module_groups .append (group )
474544 for j in range (i , i + len (current_modules )):
@@ -506,6 +576,7 @@ def _apply_group_offloading_block_level(
506576 stream = None ,
507577 cpu_param_dict = None ,
508578 onload_self = True ,
579+ pinned_group_manager = pinned_group_manager ,
509580 )
510581 next_group = matched_module_groups [0 ] if len (matched_module_groups ) > 0 else None
511582 _apply_group_offloading_hook (module , unmatched_group , next_group , unpin_after_use )
@@ -518,6 +589,7 @@ def _apply_group_offloading_leaf_level(
518589 non_blocking : bool ,
519590 stream : Optional [torch .cuda .Stream ] = None ,
520591 unpin_after_use : bool = False ,
592+ pinned_group_manager : Optional [PinnedGroupManager ] = None ,
521593) -> None :
522594 r"""
523595 This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -559,6 +631,7 @@ def _apply_group_offloading_leaf_level(
559631 stream = stream ,
560632 cpu_param_dict = cpu_param_dict ,
561633 onload_self = True ,
634+ pinned_group_manager = pinned_group_manager ,
562635 )
563636 _apply_group_offloading_hook (submodule , group , None , unpin_after_use )
564637 modules_with_group_offloading .add (name )
@@ -604,6 +677,7 @@ def _apply_group_offloading_leaf_level(
604677 stream = stream ,
605678 cpu_param_dict = cpu_param_dict ,
606679 onload_self = True ,
680+ pinned_group_manager = pinned_group_manager ,
607681 )
608682 _apply_group_offloading_hook (parent_module , group , None , unpin_after_use )
609683
@@ -623,6 +697,7 @@ def _apply_group_offloading_leaf_level(
623697 stream = None ,
624698 cpu_param_dict = None ,
625699 onload_self = True ,
700+ pinned_group_manager = pinned_group_manager ,
626701 )
627702 _apply_lazy_group_offloading_hook (module , unmatched_group , None )
628703
0 commit comments