2727from .hooks import HookRegistry , ModelHook
2828
2929
30+ VALID_PIN_GROUPS = {"all" , "first_last" }
31+
32+
3033if is_accelerate_available ():
3134 from accelerate .hooks import AlignDevicesHook , CpuOffload
3235 from accelerate .utils import send_to_device
@@ -302,36 +305,19 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
302305 # method is the onload_leader of the group.
303306 if self .group .onload_leader is None :
304307 self .group .onload_leader = module
308+ is_leader = self .group .onload_leader == module
309+ should_onload_next_group = self .next_group is not None and not self .next_group .onload_self
310+ should_orchestrate = self .group .pinned or is_leader
311+
312+ if should_orchestrate :
313+ # Pinned groups keep their params on the onload device; orchestrate onload/prefetch/sync every call.
314+ if self .group .pinned :
315+ if is_leader and not self ._is_group_on_device ():
316+ self .group .onload_ ()
317+ else :
318+ if is_leader and self .group .onload_self :
319+ self .group .onload_ ()
305320
306- if self .group .pinned :
307- if self .group .onload_leader == module and not self ._is_group_on_device ():
308- self .group .onload_ ()
309-
310- should_onload_next_group = self .next_group is not None and not self .next_group .onload_self
311- if should_onload_next_group :
312- self .next_group .onload_ ()
313-
314- should_synchronize = (
315- not self .group .onload_self
316- and self .group .stream is not None
317- and not should_onload_next_group
318- and not self .group .record_stream
319- )
320- if should_synchronize :
321- self .group .stream .synchronize ()
322-
323- args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
324- kwargs = self ._send_kwargs_to_device (kwargs )
325- return args , kwargs
326-
327- # If the current module is the onload_leader of the group, we onload the group if it is supposed
328- # to onload itself. In the case of using prefetching with streams, we onload the next group if
329- # it is not supposed to onload itself.
330- if self .group .onload_leader == module :
331- if self .group .onload_self :
332- self .group .onload_ ()
333-
334- should_onload_next_group = self .next_group is not None and not self .next_group .onload_self
335321 if should_onload_next_group :
336322 self .next_group .onload_ ()
337323
@@ -345,9 +331,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
345331 # If this group didn't onload itself, it means it was asynchronously onloaded by the
346332 # previous group. We need to synchronize the side stream to ensure parameters
347333 # are completely loaded to proceed with forward pass. Without this, uninitialized
348- # weights will be used in the computation, leading to incorrect results
349- # Also, we should only do this synchronization if we don't already do it from the sync call in
350- # self.next_group.onload_, hence the `not should_onload_next_group` check.
334+ # weights will be used in the computation, leading to incorrect results.
351335 self .group .stream .synchronize ()
352336
353337 args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
@@ -546,9 +530,6 @@ def pre_forward(self, module, *args, **kwargs):
546530 return args , kwargs
547531
548532
549- VALID_PIN_GROUPS = {"all" , "first_last" }
550-
551-
552533def _validate_pin_groups (pin_groups : Optional [Union [str , Callable ]]) -> Optional [Union [str , Callable ]]:
553534 if pin_groups is None or callable (pin_groups ):
554535 return pin_groups
@@ -708,9 +689,6 @@ def apply_group_offloading(
708689
709690
710691def _apply_group_offloading (module : torch .nn .Module , config : GroupOffloadingConfig ) -> None :
711- registry = HookRegistry .check_if_exists_or_initialize (module )
712- registry ._group_offload_pin_groups = config .pin_groups
713-
714692 if config .offload_type == GroupOffloadingType .BLOCK_LEVEL :
715693 _apply_group_offloading_block_level (module , config )
716694 elif config .offload_type == GroupOffloadingType .LEAF_LEVEL :
0 commit comments