Skip to content

Commit 69cdc25

Browse files
a-r-r-o-wsayakpaul
andauthored
Fix group offloading synchronization bug for parameter-only GroupModule's (#12077)
* update * update * refactor * fuck yeah * make style * Update src/diffusers/hooks/group_offloading.py Co-authored-by: Sayak Paul <[email protected]> * Update src/diffusers/hooks/group_offloading.py --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent cfd6ec7 commit 69cdc25

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def _offload_to_memory(self):
245245
param.data = self.cpu_param_dict[param]
246246
for buffer in self.buffers:
247247
buffer.data = self.cpu_param_dict[buffer]
248-
249248
else:
250249
for group_module in self.modules:
251250
group_module.to(self.offload_device, non_blocking=False)
@@ -303,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
303302
if self.group.onload_leader == module:
304303
if self.group.onload_self:
305304
self.group.onload_()
306-
if self.next_group is not None and not self.next_group.onload_self:
305+
306+
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
307+
if should_onload_next_group:
307308
self.next_group.onload_()
308309

310+
should_synchronize = (
311+
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
312+
)
313+
if should_synchronize:
314+
# If this group didn't onload itself, it means it was asynchronously onloaded by the
315+
# previous group. We need to synchronize the side stream to ensure parameters
316+
# are completely loaded to proceed with forward pass. Without this, uninitialized
317+
# weights will be used in the computation, leading to incorrect results
318+
# Also, we should only do this synchronization if we don't already do it from the sync call in
319+
# self.next_group.onload_, hence the `not should_onload_next_group` check.
320+
self.group.stream.synchronize()
321+
309322
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
310323
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
311324
return args, kwargs

0 commit comments

Comments
 (0)