 logger = get_logger(__name__)  # pylint: disable=invalid-name


-class PinnedGroupManager:
-    """
-    Manages a sliding window of pinned module groups to limit total CPU memory usage.
-    Only keeps up to max_pinned_groups pinned at any given time, unpinning oldest
-    ones when the limit is reached.
-    """
-
-    def __init__(self, max_pinned_groups: Optional[int] = None):
-        self.max_pinned_groups = max_pinned_groups
-        self.pinned_groups = []
-
-    def register_group(self, group: "ModuleGroup") -> None:
-        """Register a group with the manager"""
-        if self.max_pinned_groups is not None:
-            group._pinned_group_manager = self
-
-    def on_group_pinned(self, group: "ModuleGroup") -> None:
-        """Called when a group is pinned, manages the sliding window"""
-        if self.max_pinned_groups is None:
-            return
-
-        # Add the newly pinned group to our tracking list
-        if group not in self.pinned_groups:
-            self.pinned_groups.append(group)
-
-        # If we've exceeded our limit, unpin the oldest group(s)
-        while len(self.pinned_groups) > self.max_pinned_groups:
-            oldest_group = self.pinned_groups.pop(0)
-            # Only unpin if it's not the group we just pinned
-            if oldest_group != group and oldest_group.pinned_memory:
-                oldest_group.unpin_memory_()
-
-    def on_group_unpinned(self, group: "ModuleGroup") -> None:
-        """Called when a group is manually unpinned"""
-        if self.max_pinned_groups is None:
-            return
-
-        # Remove the group from our tracking list
-        if group in self.pinned_groups:
-            self.pinned_groups.remove(group)
+# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes


 # fmt: off
@@ -100,7 +61,6 @@ def __init__(
         stream: Optional[torch.cuda.Stream] = None,
         cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
-        pinned_group_manager: Optional["PinnedGroupManager"] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -113,58 +73,38 @@ def __init__(
         self.stream = stream
         self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
-        # Track if we've pinned our group's memory
-        self.pinned_memory = False
-        # Reference to the pinned group manager for sliding window functionality
-        self._pinned_group_manager = pinned_group_manager
+        # Track whether the CPU param dict has been prepared (kept for compatibility)
+        self.cpu_dict_prepared = False

         if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+            # Now we'll create the dict on demand
+            self.cpu_param_dict = {}

     def pin_memory_(self):
-        r"""Pin the memory of this group's parameters for faster transfer."""
-        if self.stream is not None and not self.pinned_memory:
-            # Create the pinned memory dict just for this group's parameters
-            self.cpu_param_dict = {}
+        r"""Prepare the CPU parameter dictionary for reference (no pinning)."""
+        if self.stream is not None and not self.cpu_dict_prepared:
+            # Create a reference-only CPU parameter dict without pinning
             for group_module in self.modules:
                 for param in group_module.parameters():
                     if param.device == self.offload_device:
-                        pinned_data = param.data.pin_memory()
-                        self.cpu_param_dict[param] = pinned_data
-                        param.data = pinned_data
+                        # Store a reference without pinning
+                        self.cpu_param_dict[param] = param.data
             if self.parameters is not None:
                 for param in self.parameters:
                     if param.device == self.offload_device:
-                        pinned_data = param.data.pin_memory()
-                        self.cpu_param_dict[param] = pinned_data
-                        param.data = pinned_data
+                        # Store a reference without pinning
+                        self.cpu_param_dict[param] = param.data
+            self.cpu_dict_prepared = True
+            # For API compatibility
             self.pinned_memory = True

-            # Notify the manager that this group has been pinned
-            if hasattr(self, "_pinned_group_manager") and self._pinned_group_manager is not None:
-                self._pinned_group_manager.on_group_pinned(self)
-
     def unpin_memory_(self):
-        r"""Unpin the memory of this group's parameters to free up CPU RAM."""
+        r"""Kept for API compatibility; only resets bookkeeping flags, since memory is no longer pinned."""
+        # No need to unpin since we're not pinning memory anymore
         if self.stream is not None and self.pinned_memory:
-            # Only unpin if we're currently on CPU (i.e., offloaded)
-            for group_module in self.modules:
-                if not any(p.device == self.onload_device for p in group_module.parameters()):
-                    for param in group_module.parameters():
-                        if param in self.cpu_param_dict and param.device == self.offload_device:
-                            # Create a new non-pinned copy and replace
-                            param.data = param.data.clone()
-            if self.parameters is not None:
-                for param in self.parameters:
-                    if param in self.cpu_param_dict and param.device == self.offload_device:
-                        param.data = param.data.clone()
-            # Clear the CPU param dict
-            self.cpu_param_dict = {}
+            # Just mark as unpinned for compatibility
             self.pinned_memory = False
-
-            # Notify the manager that this group has been unpinned
-            if hasattr(self, "_pinned_group_manager") and self._pinned_group_manager is not None:
-                self._pinned_group_manager.on_group_unpinned(self)
+            self.cpu_dict_prepared = False

     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -378,7 +318,6 @@ def apply_group_offloading(
     non_blocking: bool = False,
     use_stream: bool = False,
     unpin_after_use: bool = False,
-    max_pinned_groups: Optional[int] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -426,13 +365,7 @@ def apply_group_offloading(
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
         unpin_after_use (`bool`, defaults to `False`):
-            If True, pinned memory is released after a module group has been offloaded to CPU. This reduces CPU memory
-            usage at the cost of potentially slower data transfer when the group is loaded again. Useful for large models
-            with limited CPU RAM.
-        max_pinned_groups (`int`, *optional*):
-            If set, limits the number of module groups that can have pinned memory at any given time. When this limit
-            is reached, the oldest pinned groups are unpinned to make room for new ones. This implements a sliding window
-            approach to manage CPU memory usage while maintaining good performance for the most recently used groups.
+            Legacy parameter kept for API compatibility. Has no effect as we no longer use pinned memory.

     Example:
         ```python
@@ -450,8 +383,7 @@ def apply_group_offloading(
         ...     offload_type="block_level",
         ...     num_blocks_per_group=2,
         ...     use_stream=True,
-        ...     unpin_after_use=False,  # Set to True to reduce CPU memory usage
-        ...     max_pinned_groups=5,  # Limit to 5 pinned groups at a time
+        ...     unpin_after_use=False,  # Legacy parameter, no effect
         ... )
         ```
     """
@@ -465,11 +397,7 @@ def apply_group_offloading(

     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

-    # Create a pinned group manager if max_pinned_groups is set
-    pinned_group_manager = None
-    if max_pinned_groups is not None and stream is not None:
-        pinned_group_manager = PinnedGroupManager(max_pinned_groups)
-        logger.info(f"Using sliding window approach with maximum of {max_pinned_groups} pinned groups")
+    # We no longer need a pinned group manager as we're not using pinned memory

     if offload_type == "block_level":
         if num_blocks_per_group is None:
@@ -483,11 +411,10 @@ def apply_group_offloading(
             non_blocking,
             stream,
             unpin_after_use,
-            pinned_group_manager,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream, unpin_after_use, pinned_group_manager
+            module, offload_device, onload_device, non_blocking, stream, unpin_after_use
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -501,7 +428,6 @@ def _apply_group_offloading_block_level(
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
     unpin_after_use: bool = False,
-    pinned_group_manager: Optional[PinnedGroupManager] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -548,7 +474,6 @@ def _apply_group_offloading_block_level(
             stream=stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=stream is None,
-            pinned_group_manager=pinned_group_manager,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -586,7 +511,6 @@ def _apply_group_offloading_block_level(
         stream=None,
         cpu_param_dict=None,
         onload_self=True,
-        pinned_group_manager=pinned_group_manager,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
     _apply_group_offloading_hook(module, unmatched_group, next_group, unpin_after_use)
@@ -599,7 +523,6 @@ def _apply_group_offloading_leaf_level(
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
     unpin_after_use: bool = False,
-    pinned_group_manager: Optional[PinnedGroupManager] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -641,7 +564,6 @@ def _apply_group_offloading_leaf_level(
             stream=stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
-            pinned_group_manager=pinned_group_manager,
         )
         _apply_group_offloading_hook(submodule, group, None, unpin_after_use)
         modules_with_group_offloading.add(name)
@@ -687,7 +609,6 @@ def _apply_group_offloading_leaf_level(
         stream=stream,
         cpu_param_dict=cpu_param_dict,
         onload_self=True,
-        pinned_group_manager=pinned_group_manager,
     )
     _apply_group_offloading_hook(parent_module, group, None, unpin_after_use)

@@ -707,7 +628,6 @@ def _apply_group_offloading_leaf_level(
         stream=None,
         cpu_param_dict=None,
         onload_self=True,
-        pinned_group_manager=pinned_group_manager,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)

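For background on why dropping the pinning path removes the CPU memory spikes, here is a minimal sketch (not part of the commit; tensor names and shapes are made up, and `pin_memory()` requires a CUDA-enabled PyTorch build). `Tensor.pin_memory()` returns a new page-locked copy, so the old code transiently held two CPU copies of every parameter it pinned, while the new `cpu_param_dict` stores plain references and allocates nothing:

```python
import torch

# Illustrative only: compare the old pinning path with the new reference-only path.
param = torch.nn.Parameter(torch.randn(1024, 1024))  # ~4 MiB of pageable CPU memory

# Old path: pin_memory() allocates a *new* page-locked buffer and copies the data
# into it, so CPU usage temporarily doubles for this parameter until the old
# pageable copy is freed.
pinned = param.data.pin_memory()
assert pinned.is_pinned()
assert pinned.data_ptr() != param.data.data_ptr()  # distinct storage

# New path: cpu_param_dict just keeps a reference to the existing tensor, so no
# extra CPU memory is allocated.
reference = param.data
assert reference.data_ptr() == param.data.data_ptr()  # same storage
```

The trade-off follows directly: pinned buffers are what make `non_blocking=True` host-to-device copies truly asynchronous on the CUDA stream, so keeping only pageable references holds CPU RAM flat at the cost of transfers that effectively synchronize.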