@@ -95,7 +95,7 @@ def __init__(
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False

-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
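The comment above is worth unpacking: a bare truthiness check treats an empty-string `group_id` the same as `None`. A minimal standalone sketch of the difference (illustrative, not part of the patch):

```python
# `or` falls back for every falsy value, so an explicitly empty group id
# would be silently replaced; the `is not None` form preserves it.
group_id = ""
assert (group_id or "fallback") == "fallback"                     # "" is lost
assert (group_id if group_id is not None else "fallback") == ""  # "" is kept
```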
@@ -115,6 +115,12 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()

+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
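Caching the accelerator module once in `__init__` replaces the per-call lookups removed later in this patch. A hedged sketch of the resolution pattern, assuming a PyTorch build that may or may not expose `torch.accelerator`:

```python
import torch

# Resolve the device module (torch.cuda, torch.xpu, ...) for the active
# accelerator, falling back to torch.cuda on older PyTorch versions. The
# extra None check is a defensive addition for CPU-only environments.
if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
    accelerator_module = getattr(torch, torch.accelerator.current_accelerator().type)
else:
    accelerator_module = torch.cuda

print(accelerator_module.__name__)  # e.g. "torch.cuda"
```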
@@ -138,112 +144,76 @@ def _init_cpu_param_dict(self):

     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None

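The refactor keeps the behavior: tensors already in page-locked memory are reused, everything else is pinned on entry. Pinning matters because only page-locked host memory supports genuinely asynchronous host-to-device copies; a minimal sketch (illustrative, not from the patch):

```python
import torch

cpu_tensor = torch.randn(1024, 1024)  # pageable host memory
pinned = cpu_tensor.pin_memory()      # page-locked copy of the same data
if torch.cuda.is_available():
    # Only with a pinned source does non_blocking=True actually overlap
    # the copy with computation on the current stream.
    gpu_tensor = pinned.to("cuda", non_blocking=True)
```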
-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())

-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)

         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)

         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)
-    def _onload_from_disk(self, current_stream):
+    def _onload_from_disk(self):
         if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
-
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()

-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

-        if self.offload_to_disk_path:
             if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]

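For reference, the `safetensors` calls above follow the library's standard API. A hedged round-trip sketch with illustrative names (the real path comes from `self.safetensors_file_path`):

```python
import torch
import safetensors.torch

state = {"weight": torch.randn(4, 4)}  # illustrative tensor name
safetensors.torch.save_file(state, "group.safetensors")

# `device` accepts a device string; loading to "cpu" first (the stream path
# above) allows pinning and an asynchronous copy to the accelerator afterwards.
loaded = safetensors.torch.load_file("group.safetensors", device="cpu")
```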
+    def _onload_from_memory(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()

+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
         with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
             else:
-                self._onload_from_memory(current_stream)
+                self._process_tensors_from_modules(None)

     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if _all_ the desired
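Both onload paths now share the same overlap pattern: issue the host-to-device copy on a side stream, then either `record_stream` the result or synchronize before use. A self-contained sketch of that pattern, assuming a CUDA device:

```python
import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    src = torch.randn(1024, 1024).pin_memory()

    with torch.cuda.stream(side_stream):
        dst = src.to("cuda", non_blocking=True)  # copy runs on the side stream

    # Either inform the caching allocator that the default stream also uses
    # `dst` (mirrors record_stream above) ...
    dst.record_stream(torch.cuda.current_stream())
    # ... or make the default stream wait for the copy before consuming it.
    torch.cuda.current_stream().wait_stream(side_stream)
```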
@@ -264,33 +234,36 @@ def _offload_to_disk(self):
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
             for param in self.parameters:
                 param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
                 buffer.data = self.cpu_param_dict[buffer]
-
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()

     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
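Note the deliberate switch to `non_blocking=False` on the no-stream offload path: an asynchronous device-to-host copy into pinned memory may return before the data has landed, and this code path has no stream to synchronize against. A hedged sketch of the hazard being avoided:

```python
import torch

if torch.cuda.is_available():
    gpu_tensor = torch.randn(1024, 1024, device="cuda")
    host = torch.empty(1024, 1024, pin_memory=True)
    host.copy_(gpu_tensor, non_blocking=True)  # copy may still be in flight here
    torch.cuda.synchronize()                   # required before trusting `host`
```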
@@ -307,11 +280,9 @@ class GroupOffloadingHook(ModelHook):

     _is_stateful = False

-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
         self.config = config

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -331,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         if self.group.onload_leader == module:
             if self.group.onload_self:
                 self.group.onload_()
-            if self.next_group is not None and not self.next_group.onload_self:
+
+            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+            if should_onload_next_group:
                 self.next_group.onload_()

+            should_synchronize = (
+                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+            )
+            if should_synchronize:
+                # If this group didn't onload itself, it was onloaded asynchronously by the
+                # previous group. We need to synchronize the side stream to ensure the parameters
+                # are completely loaded before proceeding with the forward pass. Without this,
+                # uninitialized weights would be used in the computation, leading to incorrect results.
+                # We should only do this synchronization if we don't already get it from the sync call in
+                # self.next_group.onload_, hence the `not should_onload_next_group` check.
+                self.group.stream.synchronize()
+
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
         kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
         return args, kwargs
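The new `should_synchronize` guard closes a correctness gap at the end of a prefetch chain. A self-contained sketch of the handshake between consecutive groups, assuming a CUDA device:

```python
import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    weights_cpu = torch.randn(1024, 1024).pin_memory()

    # pre_forward of group N: prefetch group N+1's weights asynchronously.
    with torch.cuda.stream(side_stream):
        weights_gpu = weights_cpu.to("cuda", non_blocking=True)

    # pre_forward of group N+1: it did not onload itself and there is no
    # further group to prefetch, so wait for the side stream first.
    side_stream.synchronize()
    out = weights_gpu @ weights_gpu  # safe: the copy has completed
```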
@@ -459,8 +444,8 @@ def pre_forward(self, module, *args, **kwargs):

 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -546,6 +531,8 @@ def apply_group_offloading(
     ```
     """

+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)

     stream = None
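With the normalization above, callers can now pass plain device strings. A hedged usage sketch (the import path and the toy model are illustrative; see the diffusers docs for the canonical example):

```python
import torch
from diffusers.hooks import apply_group_offloading

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
apply_group_offloading(
    model,
    onload_device="cuda",  # previously had to be torch.device("cuda")
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=1,
)
```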
@@ -633,7 +620,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +649,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +680,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +727,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None, config=config)
+        _apply_group_offloading_hook(parent_module, group, config=config)

     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +749,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
         onload_self=True,
         group_id=_GROUP_ID_LAZY_LEAF,
     )
-    _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+    _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -777,14 +763,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)


 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -793,7 +778,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)

     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()