Skip to content

Commit 0af7498

Browse files
committed
fuck yeah
1 parent 8c6edb3 commit 0af7498

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def _offload_to_memory(self):
245245
param.data = self.cpu_param_dict[param]
246246
for buffer in self.buffers:
247247
buffer.data = self.cpu_param_dict[buffer]
248-
249248
else:
250249
for group_module in self.modules:
251250
group_module.to(self.offload_device, non_blocking=False)
@@ -303,8 +302,19 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
303302
if self.group.onload_leader == module:
304303
if self.group.onload_self:
305304
self.group.onload_()
306-
if self.next_group is not None and not self.next_group.onload_self:
305+
306+
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
307+
if should_onload_next_group:
307308
self.next_group.onload_()
309+
310+
should_synchronize = not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
311+
if should_synchronize:
312+
# If this group didn't onload itself, it means it was asynchronously onloaded by the
313+
# previous group. We need to synchronize the side stream to ensure parameters
314+
# are completely loaded before proceeding with the forward pass.
315+
# Also, we should only synchronize here if it is not already done by the sync call in
316+
# self.next_group.onload_, hence the `not should_onload_next_group` check.
317+
self.group.stream.synchronize()
308318

309319
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
310320
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

0 commit comments

Comments
 (0)