@@ -59,7 +59,6 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -71,47 +70,10 @@ def __init__(
         self.buffers = buffers
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
-        # We still track if we've prepared the CPU dict for compatibility
-        self.cpu_dict_prepared = False
-
-        if self.stream is not None and self.cpu_param_dict is None:
-            # Now we'll create the dict on demand
-            self.cpu_param_dict = {}
-
-    def pin_memory_(self):
-        r"""Prepare the CPU parameter dictionary for reference (no pinning)."""
-        if self.stream is not None and not self.cpu_dict_prepared:
-            # Create a reference-only CPU parameter dict without pinning
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    if param.device == self.offload_device:
-                        # Store a reference without pinning
-                        self.cpu_param_dict[param] = param.data
-            if self.parameters is not None:
-                for param in self.parameters:
-                    if param.device == self.offload_device:
-                        # Store a reference without pinning
-                        self.cpu_param_dict[param] = param.data
-            self.cpu_dict_prepared = True
-            # For API compatibility
-            self.pinned_memory = True
-
-    def unpin_memory_(self):
-        r"""No-op method kept for API compatibility."""
-        # No need to unpin since we're not pinning memory anymore
-        if self.stream is not None and self.pinned_memory:
-            # Just mark as unpinned for compatibility
-            self.pinned_memory = False
-            self.cpu_dict_prepared = False
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
-        # Prepare CPU dict before onloading
-        if self.stream is not None and not self.cpu_dict_prepared:
-            self.pin_memory_()  # This now just prepares the CPU dict
-
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
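For reference, the removed `pin_memory_` helper only stored plain references to CPU tensors; it never page-locked them. Below is a minimal sketch of what actual pinning looks like and why it matters for asynchronous Host->Device copies. The helper name and the toy module are illustrative only, not part of this codebase:

```python
import torch

def pin_cpu_params(module: torch.nn.Module) -> dict:
    # Hypothetical helper: replace each CPU parameter's storage with a
    # page-locked (pinned) copy so later Host->Device copies can run async.
    cpu_param_dict = {}
    for param in module.parameters():
        if param.device.type == "cpu":
            param.data = param.data.pin_memory()
            cpu_param_dict[param] = param.data
    return cpu_param_dict

if torch.cuda.is_available():
    layer = torch.nn.Linear(8, 8)  # toy module kept on CPU
    pinned = pin_cpu_params(layer)
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # non_blocking=True only overlaps with compute when the source is pinned
        layer.to("cuda", non_blocking=True)
    torch.cuda.current_stream().wait_stream(stream)
```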
@@ -129,20 +91,19 @@ def onload_(self):
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+        # Synchronize if using stream
         if self.stream is not None:
             torch.cuda.current_stream().synchronize()
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = self.cpu_param_dict[param]
-        else:
-            for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+
+        # Use regular to() method for all cases - much simpler!
+        for group_module in self.modules:
+            group_module.to(self.offload_device, non_blocking=self.non_blocking)
+        if self.parameters is not None:
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+        if self.buffers is not None:
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
         # After offloading, we can unpin the memory if configured to do so
         # We'll keep it pinned by default for better performance
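The rewritten `offload_` path above leans entirely on `to()`. A rough standalone sketch of the same idea is below, with the caveat (my reading, not stated in the diff) that `non_blocking=True` copies into ordinary pageable CPU memory generally cannot overlap with compute, so the explicit stream synchronization before the move is the important correctness safeguard:

```python
import torch

def offload_module(module: torch.nn.Module, non_blocking: bool = False) -> None:
    # Wait for any in-flight kernels that still read these weights
    if torch.cuda.is_available():
        torch.cuda.current_stream().synchronize()
    # Plain .to() moves parameters and buffers; with pageable CPU memory the
    # copy behaves synchronously even if non_blocking=True is requested
    module.to("cpu", non_blocking=non_blocking)

if torch.cuda.is_available():
    layer = torch.nn.Linear(8, 8).cuda()
    offload_module(layer)
    assert all(p.device.type == "cpu" for p in layer.parameters())
```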
@@ -162,18 +123,13 @@ def __init__(
         self,
         group: ModuleGroup,
         next_group: Optional[ModuleGroup] = None,
-        unpin_after_use: bool = False,
     ) -> None:
         self.group = group
         self.next_group = next_group
-        self.unpin_after_use = unpin_after_use
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
-            # Make sure we prepare CPU dict first (if using streams) before offloading
-            if self.group.stream is not None and not self.group.cpu_dict_prepared:
-                self.group.pin_memory_()  # This now just prepares the CPU dict
-            # Now it's safe to offload
+            # Offload to CPU
             self.group.offload_()
         return module
 
@@ -199,9 +155,6 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
     def post_forward(self, module: torch.nn.Module, output):
         if self.group.offload_leader == module:
             self.group.offload_()
-            # This is now a no-op but kept for API compatibility
-            if self.unpin_after_use and self.group.cpu_dict_prepared:
-                self.group.unpin_memory_()
         return output
 
 
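For readers less familiar with the hook flow: `pre_forward` onloads the group right before the leader module runs, and `post_forward` (above) offloads it again afterwards. Here is a self-contained sketch of that onload/offload sandwich using plain PyTorch forward hooks, independent of the `HookRegistry` machinery in this file:

```python
import torch

def attach_simple_offload_hooks(module: torch.nn.Module,
                                onload_device: torch.device,
                                offload_device: torch.device) -> None:
    # Illustrative only: move weights in right before forward, evict right after.
    def pre_hook(mod, args):
        mod.to(onload_device)
        return args

    def post_hook(mod, args, output):
        mod.to(offload_device)
        return output

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)

if torch.cuda.is_available():
    block = torch.nn.Linear(16, 16)  # starts on CPU
    attach_simple_offload_hooks(block, torch.device("cuda"), torch.device("cpu"))
    out = block(torch.randn(2, 16, device="cuda"))
    assert next(block.parameters()).device.type == "cpu"
```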
@@ -316,7 +269,6 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
-    unpin_after_use: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -363,8 +315,6 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
-        unpin_after_use (`bool`, defaults to `False`):
-            Legacy parameter kept for API compatibility. Has no effect as we no longer use pinned memory.
 
     Example:
         ```python
@@ -382,7 +332,6 @@ def apply_group_offloading(
         ...     offload_type="block_level",
         ...     num_blocks_per_group=2,
         ...     use_stream=True,
-        ...     unpin_after_use=False,  # Legacy parameter, no effect
         ... )
         ```
     """
@@ -409,11 +358,10 @@ def apply_group_offloading(
             onload_device,
             non_blocking,
             stream,
-            unpin_after_use,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream, unpin_after_use
+            module, offload_device, onload_device, non_blocking, stream
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -426,7 +374,6 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
-    unpin_after_use: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -447,9 +394,7 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """
 
-    # With progressive pinning approach, we'll initialize an empty CPU parameter dict
-    # and pin memory only when needed by each group
-    cpu_param_dict = {} if stream is not None else None
+    # We no longer need a CPU parameter dictionary
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -471,7 +416,6 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
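As a reminder of what feeds this loop (a simplified sketch under my own assumptions, not the exact helper used in this file): block-level offloading walks each `ModuleList`/`Sequential` child and bundles consecutive blocks into chunks of `num_blocks_per_group`, each of which becomes one `ModuleGroup` that is moved as a unit:

```python
import torch

def chunk_blocks(blocks: torch.nn.ModuleList, num_blocks_per_group: int):
    # Bundle consecutive children into groups of num_blocks_per_group.
    children = list(blocks)
    for i in range(0, len(children), num_blocks_per_group):
        yield children[i : i + num_blocks_per_group]

blocks = torch.nn.ModuleList(torch.nn.Linear(4, 4) for _ in range(5))
groups = list(chunk_blocks(blocks, num_blocks_per_group=2))
assert [len(g) for g in groups] == [2, 2, 1]
```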
@@ -485,7 +429,7 @@ def _apply_group_offloading_block_level(
         )
 
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group, unpin_after_use)
+            _apply_group_offloading_hook(group_module, group, next_group)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -508,11 +452,10 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group, unpin_after_use)
+    _apply_group_offloading_hook(module, unmatched_group, next_group)
 
 
 def _apply_group_offloading_leaf_level(
@@ -521,7 +464,6 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
-    unpin_after_use: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -544,9 +486,7 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """
 
-    # With progressive pinning approach, we'll initialize an empty CPU parameter dict
-    # and pin memory only when needed by each group
-    cpu_param_dict = {} if stream is not None else None
+    # We no longer need a CPU parameter dictionary
 
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
@@ -561,10 +501,9 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(submodule, group, None, unpin_after_use)
+        _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
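The comment above refers to parameters and buffers owned directly by non-leaf modules rather than by their children. A small sketch of how such "direct" tensors can be collected with `recurse=False`; the helper name and example module are illustrative, not from this file:

```python
import torch

def direct_params_and_buffers(module: torch.nn.Module):
    # Collect only tensors registered on this module itself, not on submodules,
    # so they can be moved together with this module's own forward pass.
    params = list(module.parameters(recurse=False))
    buffers = list(module.buffers(recurse=False))
    return params, buffers

norm = torch.nn.LayerNorm(32)  # owns weight and bias directly
params, buffers = direct_params_and_buffers(norm)
assert len(params) == 2 and len(buffers) == 0
```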
@@ -606,10 +545,9 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(parent_module, group, None, unpin_after_use)
+        _apply_group_offloading_hook(parent_module, group, None)
 
     if stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -625,7 +563,6 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)
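The lazy hook above exists so that, with streams, the recorded layer execution order can drive prefetching: while group *i* computes on the default stream, group *i+1*'s weights are copied in on a transfer stream. An illustrative standalone version of that overlap pattern, with variable names and structure that are my own rather than the repo's:

```python
import torch

if torch.cuda.is_available():
    transfer_stream = torch.cuda.Stream()
    groups = [torch.nn.Linear(64, 64) for _ in range(3)]  # stand-ins for module groups
    x = torch.randn(8, 64, device="cuda")

    with torch.cuda.stream(transfer_stream):
        groups[0].to("cuda", non_blocking=True)  # warm up the first group

    for i, layer in enumerate(groups):
        # Make sure the weights prefetched on the transfer stream are ready
        torch.cuda.current_stream().wait_stream(transfer_stream)
        if i + 1 < len(groups):
            with torch.cuda.stream(transfer_stream):
                groups[i + 1].to("cuda", non_blocking=True)  # prefetch next group
        x = layer(x)
        torch.cuda.current_stream().synchronize()
        layer.to("cpu")  # evict the group we just used
```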
@@ -635,14 +572,13 @@ def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
     next_group: Optional[ModuleGroup] = None,
-    unpin_after_use: bool = False,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, unpin_after_use)
+        hook = GroupOffloadingHook(group, next_group)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 