@@ -150,7 +150,7 @@ def to_device(t):
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
-        # For CPU offloading, use the most memory-efficient approach possible
+        # For CPU offloading
         if self.offload_device.type == "cpu":
             # Synchronize if using stream
             if self.stream is not None:
@@ -160,53 +160,38 @@ def offload_(self):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
-            # For most memory-efficient CPU offloading, let's use a special approach
-            # that simulates a full model device transfer:
-            # 1. We'll minimize RAM usage by avoiding both unnecessary copies and
-            #    the accumulation of wasted memory over time
-
-            # First, for module groups, look for the highest-level module and offload at that level
+            # For module groups, use a single, unified approach that is closest to
+            # the behavior of model.to("cpu")
             if self.modules:
-                # For each root module in the group
                 for group_module in self.modules:
-                    # Only offload if some parameters are not on CPU
+                    # Check if we need to offload this module
                     if any(p.device.type != "cpu" for p in group_module.parameters()):
+                        # Use PyTorch's built-in to() method directly, which preserves
+                        # memory mapping when moving to CPU
                         try:
-                            # Try the lowest possible CPU memory approach - this works like model.to("cpu")
-                            # but at the module level
-                            if hasattr(group_module, "_apply"):
-                                # This internal PyTorch method is what to() uses but with less overhead
-                                def cpu_tensor(t):
-                                    if t.device.type != "cpu":
-                                        return t.cpu()
-                                    return t
-
-                                # Apply to all tensors in the module without unnecessary copies
-                                group_module._apply(cpu_tensor)
-                            else:
-                                # Fall back to the direct method
-                                for param in group_module.parameters():
-                                    if param.device.type != "cpu":
-                                        param.data = param.data.cpu()
+                            # non_blocking=False for CPU transfers, as it ensures memory is
+                            # immediately available and potentially preserves memory mapping
+                            group_module.to("cpu", non_blocking=False)
                         except Exception as e:
-                            # If for any reason the optimized approach fails, fall back to the direct method
-                            logger.warning(f"Optimized CPU offloading failed: {e}, falling back to direct method")
+                            # If there's any error, fall back to parameter-level offloading
+                            logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
                             for param in group_module.parameters():
                                 if param.device.type != "cpu":
-                                    param.data = param.data.cpu()
-
-            # Handle explicit parameters - move directly to CPU
+                                    param.data = param.data.to("cpu", non_blocking=False)
+
+            # Handle explicit parameters - move directly to CPU with non_blocking=False,
+            # which can preserve memory mapping in some PyTorch versions
             if self.parameters is not None:
                 for param in self.parameters:
                     if param.device.type != "cpu":
-                        # Direct CPU transfer
-                        param.data = param.data.cpu()
+                        param.data = param.data.to("cpu", non_blocking=False)
 
-            # Handle buffers - move directly to CPU
+            # Handle buffers
             if self.buffers is not None:
                 for buffer in self.buffers:
                     if buffer.device.type != "cpu":
-                        buffer.data = buffer.data.cpu()
+                        buffer.data = buffer.data.to("cpu", non_blocking=False)
+
             # Let Python's normal reference counting handle cleanup
             # We don't force garbage collection to avoid slowing down inference
 
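For reference, here is a minimal standalone sketch of the pattern the new code adopts: attempt a single module-level `.to("cpu", non_blocking=False)` transfer, and fall back to per-parameter moves only if that raises. The names `SimpleGroup` and `offload_to_cpu` are illustrative stand-ins, not the actual class or method in this file.

```python
# Sketch only: illustrates module-level CPU offload with a parameter-level fallback.
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class SimpleGroup:
    def __init__(self, modules, parameters=None, buffers=None):
        self.modules = modules
        self.parameters = parameters or []
        self.buffers = buffers or []

    def offload_to_cpu(self):
        for module in self.modules:
            # Skip modules that are already fully on CPU
            if any(p.device.type != "cpu" for p in module.parameters()):
                try:
                    # Module-level transfer, closest to model.to("cpu")
                    module.to("cpu", non_blocking=False)
                except Exception as e:
                    logger.warning(f"Module-level offload failed: {e}, using per-parameter fallback")
                    for param in module.parameters():
                        if param.device.type != "cpu":
                            param.data = param.data.to("cpu", non_blocking=False)
        # Explicit parameters and buffers tracked outside the modules
        for param in self.parameters:
            if param.device.type != "cpu":
                param.data = param.data.to("cpu", non_blocking=False)
        for buffer in self.buffers:
            if buffer.device.type != "cpu":
                buffer.data = buffer.data.to("cpu", non_blocking=False)


# Usage: offload a small block from GPU back to CPU between forward passes.
if torch.cuda.is_available():
    block = nn.Linear(64, 64).cuda()
    group = SimpleGroup(modules=[block])
    group.offload_to_cpu()
    assert all(p.device.type == "cpu" for p in block.parameters())
```

The design choice mirrors the diff's own comments: a single `.to("cpu")` call moves parameters and buffers in one pass, like `model.to("cpu")`, while the per-parameter loop remains only as a fallback path.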