@@ -245,7 +245,6 @@ def _offload_to_memory(self):
245
245
param .data = self .cpu_param_dict [param ]
246
246
for buffer in self .buffers :
247
247
buffer .data = self .cpu_param_dict [buffer ]
248
-
249
248
else :
250
249
for group_module in self .modules :
251
250
group_module .to (self .offload_device , non_blocking = False )
@@ -303,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
303
302
if self .group .onload_leader == module :
304
303
if self .group .onload_self :
305
304
self .group .onload_ ()
306
- if self .next_group is not None and not self .next_group .onload_self :
305
+
306
+ should_onload_next_group = self .next_group is not None and not self .next_group .onload_self
307
+ if should_onload_next_group :
307
308
self .next_group .onload_ ()
308
309
310
+ should_synchronize = (
311
+ not self .group .onload_self and self .group .stream is not None and not should_onload_next_group
312
+ )
313
+ if should_synchronize :
314
+ # If this group didn't onload itself, it means it was asynchronously onloaded by the
315
+ # previous group. We need to synchronize the side stream to ensure parameters
316
+ # are completely loaded to proceed with forward pass. Without this, uninitialized
317
+ # weights will be used in the computation, leading to incorrect results.
318
+ # Also, we should only do this synchronization if we don't already do it from the sync call in
319
+ # self.next_group.onload_, hence the `not should_onload_next_group` check.
320
+ self .group .stream .synchronize ()
321
+
309
322
args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
310
323
kwargs = send_to_device (kwargs , self .group .onload_device , non_blocking = self .group .non_blocking )
311
324
return args , kwargs
0 commit comments