@@ -161,7 +161,7 @@ def _pinned_memory_tensors(self):
161161 finally :
162162 pinned_dict = None
163163
164- def _transfer_tensor_to_device (self , tensor , source_tensor , default_stream = None ):
164+ def _transfer_tensor_to_device (self , tensor , source_tensor , default_stream ):
165165 tensor .data = source_tensor .to (self .onload_device , non_blocking = self .non_blocking )
166166 if self .record_stream :
167167 tensor .data .record_stream (default_stream )
@@ -295,7 +295,11 @@ def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None
295295 self .config = config
296296
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
    """Eagerly materialize safetensor files for disk-offloaded groups.

    Disk offload only: the offload leader writes the group's files up front so
    they exist right after ``enable_group_offload``, which lets callers/tests
    inspect the offload dir before the first forward, e.g.::

        model.enable_group_offload(..., offload_to_disk_path=tmpdir)
        assert glob.glob(f"{tmpdir}/*.safetensors")

    In-memory offload stays lazy so adapters can still be loaded before the
    first forward pass.
    """
    group = self.group
    if group.offload_to_disk_path is None:
        # In-memory offload: nothing to do at initialization time.
        return module
    if group.offload_leader == module:
        group.offload_()
    return module
@@ -305,18 +309,18 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
305309 # method is the onload_leader of the group.
306310 if self .group .onload_leader is None :
307311 self .group .onload_leader = module
308- is_leader = self . group . onload_leader == module
312+
309313 should_onload_next_group = self .next_group is not None and not self .next_group .onload_self
310- should_orchestrate = self .group .pinned or is_leader
311314
312- if should_orchestrate :
313- # Pinned groups keep their params on the onload device; orchestrate onload/prefetch/sync every call.
315+ if self .group .onload_leader == module :
316+ # If the current module is the onload_leader of the group, we onload the group if it is supposed
317+ # to onload itself. In the case of using prefetching with streams, we onload the next group if
318+ # it is not supposed to onload itself.
314319 if self .group .pinned :
315- if is_leader and not self ._is_group_on_device ():
316- self .group .onload_ ()
317- else :
318- if is_leader and self .group .onload_self :
320+ if not self ._is_group_on_device ():
319321 self .group .onload_ ()
322+ elif self .group .onload_self :
323+ self .group .onload_ ()
320324
321325 if should_onload_next_group :
322326 self .next_group .onload_ ()
@@ -335,18 +339,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
335339 self .group .stream .synchronize ()
336340
337341 args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
338- kwargs = self ._send_kwargs_to_device (kwargs )
339- return args , kwargs
340-
341- def post_forward (self , module : torch .nn .Module , output ):
342- if self .group .pinned :
343- return output
344-
345- if self .group .offload_leader == module :
346- self .group .offload_ ()
347- return output
348342
349- def _send_kwargs_to_device (self , kwargs ):
343+ # Some Autoencoder models use a feature cache that is passed through submodules and modified in place.
344+ # The `send_to_device` call returns a copy of this feature cache object which breaks the inplace updates.
345+ # Use `exclude_kwargs` to mark these cache features so they are not moved.
350346 exclude_kwargs = self .config .exclude_kwargs or []
351347 if exclude_kwargs :
352348 moved_kwargs = send_to_device (
@@ -355,8 +351,19 @@ def _send_kwargs_to_device(self, kwargs):
355351 non_blocking = self .group .non_blocking ,
356352 )
357353 kwargs .update (moved_kwargs )
358- return kwargs
359- return send_to_device (kwargs , self .group .onload_device , non_blocking = self .group .non_blocking )
354+ else :
355+ kwargs = send_to_device (kwargs , self .group .onload_device , non_blocking = self .group .non_blocking )
356+
357+ return args , kwargs
358+
def post_forward(self, module: torch.nn.Module, output):
    """Offload the group after its offload leader finishes the forward pass.

    Pinned groups stay resident on the onload device, so they are never
    offloaded here. For non-pinned groups, offload happens exactly once per
    forward: when the module completing here is the group's offload leader.
    The forward output is always passed through unchanged.
    """
    group = self.group
    if not group.pinned and group.offload_leader == module:
        group.offload_()
    return output
360367
361368 def _is_group_on_device (self ) -> bool :
362369 tensors = []
def _validate_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]:
    """Validate the `pin_groups` argument.

    Accepts None (no pinning), a callable predicate, or one of the string
    presets in `VALID_PIN_GROUPS`. Any other value raises ValueError.

    NOTE(review): the head of this function lies outside the visible hunk; the
    None/callable fast paths below are reconstructed from the visible returns
    and the error message — confirm against the full file.
    """
    if pin_groups is None:
        return None
    if callable(pin_groups):
        return pin_groups
    if isinstance(pin_groups, str) and pin_groups in VALID_PIN_GROUPS:
        return pin_groups
    # A single terminal raise covers every invalid value. The previously added
    # `elif isinstance(pin_groups, str) and pin_groups not in VALID_PIN_GROUPS`
    # branch raised this exact same message and was unreachable dead weight.
    raise ValueError(
        f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
    )
0 commit comments