@@ -82,38 +82,65 @@ def onload_(self):
8282 self .stream .synchronize ()
8383
8484 with context :
85- # Use direct per-parameter transfers rather than module-level transfers
86- # This gives us more control and potentially better memory management
87- for group_module in self .modules :
88- # Check if any parameter needs moving
89- if any (p .device != self .onload_device for p in group_module .parameters ()):
90- for param in group_module .parameters ():
91- if param .device != self .onload_device :
92- # Use direct CUDA transfer for each parameter
93- if self .onload_device .type == "cuda" :
94- param .data = param .data .cuda (self .onload_device .index ,
95- non_blocking = self .non_blocking )
85+ # Use the most efficient module-level transfer when possible
86+ # This approach mirrors how PyTorch handles full model transfers
87+ if self .modules :
88+ for group_module in self .modules :
89+ # Only onload if some parameters are not on the target device
90+ if any (p .device != self .onload_device for p in group_module .parameters ()):
91+ try :
92+ # Try the most efficient approach using _apply
93+ if hasattr (group_module , "_apply" ):
94+ # This is what module.to() uses internally
95+ def to_device (t ):
96+ if t .device != self .onload_device :
97+ if self .onload_device .type == "cuda" :
98+ return t .cuda (self .onload_device .index ,
99+ non_blocking = self .non_blocking )
100+ else :
101+ return t .to (self .onload_device ,
102+ non_blocking = self .non_blocking )
103+ return t
104+
105+ # Apply to all tensors without unnecessary copies
106+ group_module ._apply (to_device )
96107 else :
97- param .data = param .data .to (self .onload_device ,
98- non_blocking = self .non_blocking )
99-
108+ # Fallback to direct parameter transfer
109+ for param in group_module .parameters ():
110+ if param .device != self .onload_device :
111+ if self .onload_device .type == "cuda" :
112+ param .data = param .data .cuda (self .onload_device .index ,
113+ non_blocking = self .non_blocking )
114+ else :
115+ param .data = param .data .to (self .onload_device ,
116+ non_blocking = self .non_blocking )
117+ except Exception as e :
118+ # If optimization fails, fall back to direct parameter transfer
119+ logger .warning (f"Optimized onloading failed: { e } , falling back to direct method" )
120+ for param in group_module .parameters ():
121+ if param .device != self .onload_device :
122+ if self .onload_device .type == "cuda" :
123+ param .data = param .data .cuda (self .onload_device .index ,
124+ non_blocking = self .non_blocking )
125+ else :
126+ param .data = param .data .to (self .onload_device ,
127+ non_blocking = self .non_blocking )
128+
100129 # Handle explicit parameters
101130 if self .parameters is not None :
102131 for param in self .parameters :
103132 if param .device != self .onload_device :
104- # Use direct CUDA transfer for each parameter
105133 if self .onload_device .type == "cuda" :
106134 param .data = param .data .cuda (self .onload_device .index ,
107135 non_blocking = self .non_blocking )
108136 else :
109137 param .data = param .data .to (self .onload_device ,
110138 non_blocking = self .non_blocking )
111-
139+
112140 # Handle buffers
113141 if self .buffers is not None :
114142 for buffer in self .buffers :
115143 if buffer .device != self .onload_device :
116- # Use direct CUDA transfer for each buffer
117144 if self .onload_device .type == "cuda" :
118145 buffer .data = buffer .data .cuda (self .onload_device .index ,
119146 non_blocking = self .non_blocking )
@@ -133,29 +160,58 @@ def offload_(self):
133160 if torch .cuda .is_available ():
134161 torch .cuda .empty_cache ()
135162
136- # Instead of using to() on the whole module which might create copies,
137- # directly move each parameter's data to CPU with cpu() which uses
138- # the memory-optimized path
139- for group_module in self .modules :
140- # Check if any parameter needs moving
141- if any (p .device .type != "cpu" for p in group_module .parameters ()):
142- for param in group_module .parameters ():
143- if param .device .type != "cpu" :
144- # Use direct cpu() method which is more memory-efficient than to()
145- param .data = param .data .cpu ()
163+ # For CPU offloading, move tensors at the module level, in the same way
164+ # that a full model device transfer does.
165+ # The aim is to minimize RAM usage by avoiding both unnecessary copies and
166+ # the accumulation of wasted memory over repeated offloads.
167+
168+ # First, for module groups, offload each module in the group as a whole
169+ if self .modules :
170+ # For each root module in the group
171+ for group_module in self .modules :
172+ # Only offload if some parameters are not on CPU
173+ if any (p .device .type != "cpu" for p in group_module .parameters ()):
174+ try :
175+ # Try the lowest possible CPU memory approach - this works like model.to("cpu")
176+ # but at the module level
177+ if hasattr (group_module , "_apply" ):
178+ # _apply is the internal PyTorch method that to() is built on; calling it directly is intended to avoid to()'s extra overhead
179+ def cpu_tensor (t ):
180+ if t .device .type != "cpu" :
181+ return t .cpu ()
182+ return t
183+
184+ # Apply to all tensors in the module without unnecessary copies
185+ group_module ._apply (cpu_tensor )
186+ else :
187+ # Fallback to the direct method
188+ for param in group_module .parameters ():
189+ if param .device .type != "cpu" :
190+ param .data = param .data .cpu ()
191+ except Exception as e :
192+ # If for any reason the optimized approach fails, fall back to direct method
193+ logger .warning (f"Optimized CPU offloading failed: { e } , falling back to direct method" )
194+ for param in group_module .parameters ():
195+ if param .device .type != "cpu" :
196+ param .data = param .data .cpu ()
146197
147198 # Handle explicit parameters - move directly to CPU
148199 if self .parameters is not None :
149200 for param in self .parameters :
150201 if param .device .type != "cpu" :
151- # Direct CPU transfer with cpu() method
202+ # Direct CPU transfer
152203 param .data = param .data .cpu ()
153204
154205 # Handle buffers - move directly to CPU
155206 if self .buffers is not None :
156207 for buffer in self .buffers :
157208 if buffer .device .type != "cpu" :
158209 buffer .data = buffer .data .cpu ()
210+
211+ # Force garbage collection to clean up any released memory
212+ import gc
213+ gc .collect ()
214+
159215 else :
160216 # For non-CPU offloading, synchronize if using stream
161217 if self .stream is not None :
0 commit comments