@@ -108,9 +108,9 @@ def unpin_memory_(self):
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
-        # Pin memory before onloading
-        if self.stream is not None and not self.pinned_memory:
-            self.pin_memory_()
+        # Prepare CPU dict before onloading
+        if self.stream is not None and not self.cpu_dict_prepared:
+            self.pin_memory_()  # This now just prepares the CPU dict
 
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
         if self.stream is not None:
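For context, a minimal sketch of what the reworked helpers could look like under this change: `pin_memory_` keeps its legacy name but only snapshots parameters into a CPU-side dict, and `unpin_memory_` becomes a no-op. Only `pin_memory_`, `unpin_memory_`, and `cpu_dict_prepared` come from the diff; every other name here is a hypothetical stand-in, not the actual implementation.

```python
import torch

class ModuleGroupSketch:
    """Hypothetical stand-in for the real group class; only pin_memory_,
    unpin_memory_, and cpu_dict_prepared are taken from the diff above."""

    def __init__(self, modules, stream=None):
        self.modules = modules
        self.stream = stream
        self.cpu_dict_prepared = False
        self.cpu_param_dict = {}  # illustrative name for the CPU-side storage

    def pin_memory_(self):
        # Despite the legacy name, this now only prepares the CPU dict;
        # it no longer calls Tensor.pin_memory().
        for module in self.modules:
            for name, param in module.named_parameters():
                self.cpu_param_dict[(id(module), name)] = param.detach().to("cpu")
        self.cpu_dict_prepared = True

    def unpin_memory_(self):
        # Kept as a no-op for API compatibility with the old pinned-memory path.
        pass
```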
@@ -170,9 +170,9 @@ def __init__(
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
-            # Make sure we pin memory first (if using streams) before offloading
-            if self.group.stream is not None and not self.group.pinned_memory:
-                self.group.pin_memory_()
+            # Make sure we prepare the CPU dict first (if using streams) before offloading
+            if self.group.stream is not None and not self.group.cpu_dict_prepared:
+                self.group.pin_memory_()  # This now just prepares the CPU dict
             # Now it's safe to offload
             self.group.offload_()
         return module
@@ -199,9 +199,8 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
     def post_forward(self, module: torch.nn.Module, output):
         if self.group.offload_leader == module:
             self.group.offload_()
-            # After offloading, we can optionally unpin memory to free up CPU RAM
-            # This is most useful for large models where CPU RAM is limited
-            if self.unpin_after_use and self.group.pinned_memory:
+            # This is now a no-op but kept for API compatibility
+            if self.unpin_after_use and self.group.cpu_dict_prepared:
                 self.group.unpin_memory_()
         return output
 
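A short usage sketch under the same assumptions as above, showing why the `unpin_after_use` branch in `post_forward` is now harmless: `unpin_memory_` does nothing, so the CPU dict stays resident after offloading.

```python
linear = torch.nn.Linear(4, 4)
group = ModuleGroupSketch([linear])

group.pin_memory_()             # prepares the CPU dict (no actual pinning)
assert group.cpu_dict_prepared

group.unpin_memory_()           # no-op, kept for API compatibility
assert group.cpu_dict_prepared  # the CPU copies are still held
```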