@@ -135,9 +135,58 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None
 
+    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream and current_stream is not None:
+            tensor.data.record_stream(current_stream)
+
+    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source, current_stream)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source, current_stream)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source, current_stream)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source, current_stream)
+
+    def _onload_from_disk(self, current_stream):
+        if self.stream is not None:
+            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+
+            for key, tensor_obj in self.key_to_tensor.items():
+                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
+
+            with self._pinned_memory_tensors() as pinned_memory:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
+
+            self.cpu_param_dict.clear()
+
+        else:
+            onload_device = (
+                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+            )
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+            for key, tensor_obj in self.key_to_tensor.items():
+                tensor_obj.data = loaded_tensors[key]
+
+    def _onload_from_memory(self, current_stream):
+        if self.stream is not None:
+            with self._pinned_memory_tensors() as pinned_memory:
+                self._process_tensors_from_modules(pinned_memory, current_stream)
+        else:
+            self._process_tensors_from_modules(None, current_stream)
+
     @torch.compiler.disable()
     def onload_(self):
-        r"""Onloads the group of modules to the onload_device."""
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
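
Note: the new `_transfer_tensor_to_device` helper centralizes the copy-plus-`record_stream` bookkeeping that `onload_` previously repeated for every parameter and buffer, and it only calls `record_stream` when a stream was actually passed in (the old non-stream branch applied it inconsistently, to top-level buffers only). A minimal standalone sketch of the underlying pinned-memory pattern, assuming a CUDA device; the tensor and stream names are illustrative, not from this diff:

```python
import torch

# Page-locked (pinned) host memory is what makes non_blocking=True a true async copy.
cpu_tensor = torch.randn(1024, 1024)
pinned = cpu_tensor.pin_memory()

copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    # Enqueue the host-to-device copy on a side stream.
    gpu_tensor = pinned.to("cuda", non_blocking=True)

# Mark gpu_tensor as used on the default stream so the caching allocator does not
# reuse its memory while work on either stream may still touch it.
gpu_tensor.record_stream(torch.cuda.current_stream())
copy_stream.synchronize()
```
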
@@ -175,67 +224,32 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
-
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
+                self._onload_from_memory(current_stream)
 
     @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            # TODO: we can potentially optimize this code path by checking if _all_ the desired
-            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
-            # we perform a write.
-            # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-                tensors_to_save = {
-                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
-                }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-            # We do this to free up the RAM which is still holding the tensor data.
-            for tensor_obj in self.tensor_to_key.keys():
-                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-            return
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if _all_ the desired
+        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+        # we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM which is still holding the tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
+    @torch.compiler.disable()
+    def _offload_to_memory(self):
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
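
Both disk code paths above reduce to a `safetensors.torch.save_file` / `load_file` round-trip. A self-contained sketch of just that mechanism (the file path and tensor names are illustrative, not taken from this PR):

```python
import os

import safetensors.torch
import torch

path = "offload/group_0.safetensors"  # hypothetical location
os.makedirs(os.path.dirname(path), exist_ok=True)

# save_file expects a flat {str: tensor} mapping.
tensors = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
safetensors.torch.save_file(tensors, path)

# load_file takes the target device as a string, which appears to be why
# _onload_from_disk converts a torch.device to its `.type` before calling it.
loaded = safetensors.torch.load_file(path, device="cpu")
assert torch.equal(loaded["weight"], tensors["weight"])
```
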
@@ -260,6 +274,14 @@ def offload_(self):
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
+
 
 class GroupOffloadingHook(ModelHook):
     r"""
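
For orientation, `GroupOffloadingHook` is what ends up invoking the `onload_`/`offload_` pair around a module's forward pass. Below is a hedged sketch of that shape using plain PyTorch module hooks; it is not the library's actual `ModelHook` machinery, and `group` is assumed to expose the methods defined above:

```python
import torch


def attach_offload_hooks(module: torch.nn.Module, group) -> None:
    # Bring the group's tensors onto the compute device right before forward...
    def pre_forward(mod, args):
        group.onload_()
        return args

    # ...and send them back to the offload device (or disk) right after.
    def post_forward(mod, args, output):
        group.offload_()
        return output

    module.register_forward_pre_hook(pre_forward)
    module.register_forward_hook(post_forward)
```
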
@@ -484,6 +506,8 @@ def apply_group_offloading(
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
 
+    (TODO: include example with `offload_to_disk_path`)
+
     Example:
         ```python
         >>> from diffusers import CogVideoXTransformer3DModel
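
Until the docstring TODO above is resolved upstream, here is a hedged sketch of what an `offload_to_disk_path` call might look like, assuming the parameter is accepted by `apply_group_offloading` as this diff suggests; the model and argument values are illustrative:

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
    offload_to_disk_path="./offload_dir",  # hypothetical directory for the group safetensors files
)
```
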