@@ -101,7 +101,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
 
-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
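The switch to `is not None` applies the same None-vs-falsy distinction that the `group_id` comment describes. A minimal standalone sketch (hypothetical class and values, not part of the diff) of the pitfall the explicit check avoids:

```python
class _Demo:
    def __init__(self, group_id=None):
        # `group_id or str(id(self))` would silently replace "" with the
        # object id; the explicit None check keeps "" as a legitimate id.
        self.group_id = group_id if group_id is not None else str(id(self))

print(repr(_Demo(group_id="").group_id))  # '' is preserved
print(repr(_Demo().group_id))             # falls back to str(id(self))
```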
@@ -121,6 +121,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
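Caching the accelerator module on the instance avoids re-resolving it on every onload/offload call. For reference, a standalone sketch of the same resolution pattern (assumes a PyTorch build with `torch.accelerator` and an accelerator device present; older builds fall back to `torch.cuda`):

```python
import torch

# Resolve the device-specific submodule (torch.cuda, torch.xpu, ...) once.
torch_accelerator_module = (
    getattr(torch, torch.accelerator.current_accelerator().type)
    if hasattr(torch, "accelerator")
    else torch.cuda
)

if torch.cuda.is_available():
    transfer_stream = torch_accelerator_module.Stream()
    compute_stream = torch_accelerator_module.current_stream()
```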
@@ -144,16 +150,12 @@ def _init_cpu_param_dict(self):
 
     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
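The comprehension is behavior-preserving: already-pinned tensors pass through, everything else is pinned. A small illustration of why pinning is worth the bookkeeping (hypothetical tensor; host-to-device copies only overlap with compute when the source is page-locked):

```python
import torch

cpu_tensor = torch.randn(1024, 1024)
pinned = cpu_tensor if cpu_tensor.is_pinned() else cpu_tensor.pin_memory()

if torch.cuda.is_available():
    # With a pinned source, this copy can run asynchronously on a side
    # stream instead of blocking the host.
    gpu_tensor = pinned.to("cuda", non_blocking=True)
```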
@@ -179,77 +181,47 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
             source = pinned_memory[buffer] if pinned_memory else buffer.data
             self._transfer_tensor_to_device(buffer, source, current_stream)
 
-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
-
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
 
-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
 
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = self.onload_device if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
 
+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory, current_stream)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None, current_stream)
 
     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
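The refactor folds the disk branch of the old `onload_` into `_onload_from_disk` itself: a single `load_file` call whose target depends on whether a transfer stream is in use, followed by pin-and-copy when streaming. A condensed sketch of that overlap pattern (hypothetical file name, assumes CUDA):

```python
import safetensors.torch
import torch

transfer_stream = torch.cuda.Stream()
with torch.cuda.stream(transfer_stream):
    # Load to CPU, pin, then copy asynchronously so the H2D transfer can
    # overlap with compute running on the default stream.
    cpu_tensors = safetensors.torch.load_file("weights.safetensors", device="cpu")
    gpu_tensors = {
        key: t.pin_memory().to("cuda", non_blocking=True)
        for key, t in cpu_tensors.items()
    }
# Consumers must wait on the transfer stream (or use record_stream)
# before touching gpu_tensors.
transfer_stream.synchronize()
```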
@@ -270,14 +242,10 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
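When `record_stream` is disabled, the compute stream is synchronized before parameters are re-pointed at their CPU copies, so no in-flight kernel can still be reading the device memory; with it enabled, the allocator is informed at onload time instead. A toy sketch of the two alternatives (assumes CUDA; an approximation, not the hook's exact control flow):

```python
import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        y = torch.randn(8, device="cuda") * 2  # produced on the side stream

    # Option 1: block until all queued work on the current stream is done.
    torch.cuda.current_stream().synchronize()

    # Option 2: mark `y` as used on the current stream so the caching
    # allocator does not recycle its memory prematurely.
    y.record_stream(torch.cuda.current_stream())
```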
@@ -288,15 +256,23 @@ def _offload_to_memory(self):
 
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
 
     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
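Hard-coding `non_blocking=False` on the non-stream path is deliberate: a non-blocking device-to-CPU copy can return before the bytes have landed, so the CPU tensor may be read while stale. A small sketch of the hazard (assumes CUDA; see the PyTorch documentation on `non_blocking` GPU-to-CPU copies):

```python
import torch

if torch.cuda.is_available():
    gpu = torch.randn(1024, device="cuda")

    # Potentially unsafe without an explicit synchronize(): the copy may
    # still be in flight when `racy` is first read.
    racy = gpu.to("cpu", non_blocking=True)

    # Safe: the blocking copy returns only after the data has arrived.
    safe = gpu.to("cpu", non_blocking=False)
```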
@@ -462,8 +438,8 @@ def pre_forward(self, module, *args, **kwargs):
 
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -549,6 +525,8 @@ def apply_group_offloading(
     ```
     """
 
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)
 
     stream = None
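With the widened signature and the normalization above, callers can pass plain device strings. A hedged usage sketch (checkpoint chosen for illustration; import path assumed from the diffusers hooks module; requires a CUDA machine):

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Strings are now accepted and normalized to torch.device internally.
apply_group_offloading(
    transformer,
    onload_device="cuda",
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=2,
)
```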