@@ -70,12 +70,55 @@ def __init__(
         self.stream = stream
         self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        # Track if we've pinned our group's memory
+        self.pinned_memory = False
 
         if self.stream is not None and self.cpu_param_dict is None:
             raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
 
+    def pin_memory_(self):
+        r"""Pin the memory of this group's parameters for faster transfer."""
+        if self.stream is not None and not self.pinned_memory:
+            # Create the pinned memory dict just for this group's parameters
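+            # Note: Tensor.pin_memory() returns a page-locked copy; keeping a reference in
+            # cpu_param_dict lets offload_() reuse the same pinned buffer on every offload.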
+            self.cpu_param_dict = {}
+            for group_module in self.modules:
+                for param in group_module.parameters():
+                    if param.device == self.offload_device:
+                        pinned_data = param.data.pin_memory()
+                        self.cpu_param_dict[param] = pinned_data
+                        param.data = pinned_data
+            if self.parameters is not None:
+                for param in self.parameters:
+                    if param.device == self.offload_device:
+                        pinned_data = param.data.pin_memory()
+                        self.cpu_param_dict[param] = pinned_data
+                        param.data = pinned_data
+            self.pinned_memory = True
+
+    def unpin_memory_(self):
+        r"""Unpin the memory of this group's parameters to free up CPU RAM."""
+        if self.stream is not None and self.pinned_memory:
+            # Only unpin if we're currently on CPU (i.e., offloaded)
+            for group_module in self.modules:
+                if not any(p.device == self.onload_device for p in group_module.parameters()):
+                    for param in group_module.parameters():
+                        if param in self.cpu_param_dict and param.device == self.offload_device:
+                            # Create a new non-pinned copy and replace
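+                            # .clone() allocates ordinary pageable memory; once cpu_param_dict is
+                            # cleared below, the page-locked buffer can be garbage-collected.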
+                            param.data = param.data.clone()
+            if self.parameters is not None:
+                for param in self.parameters:
+                    if param in self.cpu_param_dict and param.device == self.offload_device:
+                        param.data = param.data.clone()
+            # Clear the CPU param dict
+            self.cpu_param_dict = {}
+            self.pinned_memory = False
+
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
+        # Pin memory before onloading
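+        # With progressive pinning, the pinning cost is paid lazily, per group, the first time
+        # the group is onloaded, rather than for the entire model up front.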
+        if self.stream is not None and not self.pinned_memory:
+            self.pin_memory_()
+
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
@@ -107,6 +150,9 @@ def offload_(self):
             if self.buffers is not None:
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+
+        # Unpinning after offload is handled by GroupOffloadingHook.post_forward when
+        # unpin_after_use=True; by default memory stays pinned for faster subsequent transfers.
 
 
 class GroupOffloadingHook(ModelHook):
@@ -123,9 +169,11 @@ def __init__(
         self,
         group: ModuleGroup,
         next_group: Optional[ModuleGroup] = None,
+        unpin_after_use: bool = False,
     ) -> None:
         self.group = group
         self.next_group = next_group
+        self.unpin_after_use = unpin_after_use
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -154,6 +202,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
     def post_forward(self, module: torch.nn.Module, output):
         if self.group.offload_leader == module:
             self.group.offload_()
+            # After offloading, we can optionally unpin memory to free up CPU RAM
+            # This is most useful for large models where CPU RAM is limited
+            if self.unpin_after_use and self.group.pinned_memory:
+                self.group.unpin_memory_()
         return output
 
 
@@ -268,6 +320,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    unpin_after_use: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -314,6 +367,10 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
+        unpin_after_use (`bool`, defaults to `False`):
+            If True, pinned memory is released after a module group has been offloaded to CPU. This reduces CPU
+            memory usage at the cost of potentially slower data transfer when the group is loaded again. Useful
+            for large models with limited CPU RAM.
 
     Example:
         ```python
@@ -331,6 +388,7 @@ def apply_group_offloading(
         ...     offload_type="block_level",
         ...     num_blocks_per_group=2,
         ...     use_stream=True,
+        ...     unpin_after_use=False,  # Set to True to reduce CPU memory usage
         ... )
         ```
     """
@@ -349,10 +407,10 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, unpin_after_use
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, unpin_after_use)
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 
@@ -364,6 +422,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    unpin_after_use: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -384,12 +443,9 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
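+        unpin_after_use (`bool`, defaults to `False`):
+            If True, each group's pinned memory is released after the group has been offloaded back to CPU.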
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+    # With the progressive-pinning approach, we initialize an empty CPU parameter dict and
+    # pin memory only when it is needed by each group
+    cpu_param_dict = {} if stream is not None else None
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -425,7 +481,7 @@ def _apply_group_offloading_block_level(
         )
 
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group)
+            _apply_group_offloading_hook(group_module, group, next_group, unpin_after_use)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -452,7 +508,7 @@ def _apply_group_offloading_block_level(
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group)
+    _apply_group_offloading_hook(module, unmatched_group, next_group, unpin_after_use)
 
 
 def _apply_group_offloading_leaf_level(
@@ -461,6 +517,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    unpin_after_use: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -483,12 +540,9 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
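+        unpin_after_use (`bool`, defaults to `False`):
+            If True, each group's pinned memory is released after the group has been offloaded back to CPU.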
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+    # With the progressive-pinning approach, we initialize an empty CPU parameter dict and
+    # pin memory only when it is needed by each group
+    cpu_param_dict = {} if stream is not None else None
 
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
@@ -506,7 +560,7 @@ def _apply_group_offloading_leaf_level(
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(submodule, group, None)
+        _apply_group_offloading_hook(submodule, group, None, unpin_after_use)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -551,7 +605,7 @@ def _apply_group_offloading_leaf_level(
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(parent_module, group, None)
+        _apply_group_offloading_hook(parent_module, group, None, unpin_after_use)
 
     if stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -577,13 +631,14 @@ def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
     next_group: Optional[ModuleGroup] = None,
+    unpin_after_use: bool = False,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, next_group, unpin_after_use)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 