@@ -245,7 +245,6 @@ def _offload_to_memory(self):
245245 param .data = self .cpu_param_dict [param ]
246246 for buffer in self .buffers :
247247 buffer .data = self .cpu_param_dict [buffer ]
248-
249248 else :
250249 for group_module in self .modules :
251250 group_module .to (self .offload_device , non_blocking = False )
@@ -303,8 +302,19 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
303302 if self .group .onload_leader == module :
304303 if self .group .onload_self :
305304 self .group .onload_ ()
306- if self .next_group is not None and not self .next_group .onload_self :
305+
306+ should_onload_next_group = self .next_group is not None and not self .next_group .onload_self
307+ if should_onload_next_group :
307308 self .next_group .onload_ ()
309+
310+ should_synchronize = not self .group .onload_self and self .group .stream is not None and not should_onload_next_group
311+ if should_synchronize :
312+ # If this group didn't onload itself, it means it was asynchronously onloaded by the
313+ # previous group. We need to synchronize the side stream to ensure parameters
314+ # are completely loaded to proceed with forward pass.
315+ # Also, we should only do this synchronize if we don't already do it from the sync call in
316+ # self.next_group.onload_, hence the `not should_onload_next_group` check.
317+ self .group .stream .synchronize ()
308318
309319 args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
310320 kwargs = send_to_device (kwargs , self .group .onload_device , non_blocking = self .group .non_blocking )
0 commit comments