@@ -35,36 +35,37 @@ class PinnedGroupManager:
3535 Only keeps up to max_pinned_groups pinned at any given time, unpinning oldest
3636 ones when the limit is reached.
3737 """
38+
def __init__(self, max_pinned_groups: Optional[int] = None):
    """Create a manager that caps how many groups stay pinned at once.

    Args:
        max_pinned_groups: Maximum number of groups kept pinned
            simultaneously; ``None`` disables the cap, making the
            manager a no-op.
    """
    self.max_pinned_groups = max_pinned_groups
    # Sliding window of currently pinned groups, oldest first.
    self.pinned_groups = []
41-
42+
def register_group(self, group: "ModuleGroup") -> None:
    """Attach this manager to *group* so pin/unpin events are reported back.

    When no cap is configured the manager does nothing and the group is
    left untouched.
    """
    if self.max_pinned_groups is None:
        return
    group._pinned_group_manager = self
46-
47+
def on_group_pinned(self, group: "ModuleGroup") -> None:
    """Record that *group* was pinned and enforce the pinned-group cap.

    Maintains ``self.pinned_groups`` as a least-recently-pinned-first
    window: the newly pinned group is moved to the most-recent end
    (refreshing its recency if it was already tracked), then the oldest
    tracked groups are unpinned until the window fits
    ``max_pinned_groups``.

    Args:
        group: The group that was just pinned.
    """
    if self.max_pinned_groups is None:
        return  # No cap configured; nothing to track.

    # Move the group to the most-recent end of the window. Fix: the
    # previous code left a re-pinned group at its old (oldest) slot, so
    # the eviction loop could pop the just-pinned group, skip the unpin,
    # and drop it from tracking while it remained pinned.
    if group in self.pinned_groups:
        self.pinned_groups.remove(group)
    self.pinned_groups.append(group)

    # Evict (unpin) the oldest groups until we are within the cap.
    while len(self.pinned_groups) > self.max_pinned_groups:
        oldest_group = self.pinned_groups.pop(0)
        # Never unpin the group that was just pinned; with the recency
        # refresh above this is only reachable when max_pinned_groups == 0.
        if oldest_group is not group and oldest_group.pinned_memory:
            oldest_group.unpin_memory_()
62-
63+
def on_group_unpinned(self, group: "ModuleGroup") -> None:
    """Drop *group* from the pinned-group window after a manual unpin."""
    if self.max_pinned_groups is None:
        return

    # Forget the group; tolerate groups we never tracked (EAFP).
    try:
        self.pinned_groups.remove(group)
    except ValueError:
        pass
@@ -138,7 +139,7 @@ def pin_memory_(self):
138139 self .cpu_param_dict [param ] = pinned_data
139140 param .data = pinned_data
140141 self .pinned_memory = True
141-
142+
142143 # Notify the manager that this group has been pinned
143144 if hasattr (self , "_pinned_group_manager" ) and self ._pinned_group_manager is not None :
144145 self ._pinned_group_manager .on_group_pinned (self )
@@ -160,7 +161,7 @@ def unpin_memory_(self):
160161 # Clear the CPU param dict
161162 self .cpu_param_dict = {}
162163 self .pinned_memory = False
163-
164+
164165 # Notify the manager that this group has been unpinned
165166 if hasattr (self , "_pinned_group_manager" ) and self ._pinned_group_manager is not None :
166167 self ._pinned_group_manager .on_group_unpinned (self )
@@ -170,7 +171,7 @@ def onload_(self):
170171 # Pin memory before onloading
171172 if self .stream is not None and not self .pinned_memory :
172173 self .pin_memory_ ()
173-
174+
174175 context = nullcontext () if self .stream is None else torch .cuda .stream (self .stream )
175176 if self .stream is not None :
176177 # Wait for previous Host->Device transfer to complete
@@ -202,7 +203,7 @@ def offload_(self):
202203 if self .buffers is not None :
203204 for buffer in self .buffers :
204205 buffer .data = buffer .data .to (self .offload_device , non_blocking = self .non_blocking )
205-
206+
206207 # After offloading, we can unpin the memory if configured to do so
207208 # We'll keep it pinned by default for better performance
208209
@@ -229,6 +230,10 @@ def __init__(
229230
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
    """Offload the group once its offload leader is initialized.

    For the group's designated offload leader, pin the group's CPU
    memory first when a transfer stream is in use (and it is not already
    pinned), then offload the whole group. Every other module passes
    through untouched.
    """
    group = self.group
    if group.offload_leader == module:
        # Pinning must happen before offloading when a stream is used.
        if group.stream is not None and not group.pinned_memory:
            group.pin_memory_()
        group.offload_()
    return module
234239
@@ -459,7 +464,7 @@ def apply_group_offloading(
459464 raise ValueError ("Using streams for data transfer requires a CUDA device." )
460465
461466 _raise_error_if_accelerate_model_or_sequential_hook_present (module )
462-
467+
463468 # Create a pinned group manager if max_pinned_groups is set
464469 pinned_group_manager = None
465470 if max_pinned_groups is not None and stream is not None :
@@ -471,13 +476,18 @@ def apply_group_offloading(
471476 raise ValueError ("num_blocks_per_group must be provided when using offload_type='block_level'." )
472477
473478 _apply_group_offloading_block_level (
474- module , num_blocks_per_group , offload_device , onload_device , non_blocking ,
475- stream , unpin_after_use , pinned_group_manager
479+ module ,
480+ num_blocks_per_group ,
481+ offload_device ,
482+ onload_device ,
483+ non_blocking ,
484+ stream ,
485+ unpin_after_use ,
486+ pinned_group_manager ,
476487 )
477488 elif offload_type == "leaf_level" :
478489 _apply_group_offloading_leaf_level (
479- module , offload_device , onload_device , non_blocking ,
480- stream , unpin_after_use , pinned_group_manager
490+ module , offload_device , onload_device , non_blocking , stream , unpin_after_use , pinned_group_manager
481491 )
482492 else :
483493 raise ValueError (f"Unsupported offload_type: { offload_type } " )
0 commit comments