Skip to content

Commit 01c7d22

Browse files
committed
more workarounds to make it actually work
1 parent d2a2981 commit 01c7d22

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,8 @@ def onload_(self, module: torch.nn.Module) -> None:
116 116
if self.group.buffers is not None:
117 117
for buffer in self.group.buffers:
118 118
buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.non_blocking)
119+
if self.onload_self:
120+
torch.cuda.synchronize()
119 121

120 122
def offload_(self, module: torch.nn.Module) -> None:
121 123
if self.group.offload_leader == module:
@@ -388,7 +390,8 @@ def _apply_group_offloading_group_patterns(
388 390
if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules):
389 391
buffers.append(buffer)
390 392

391-
unmatched_modules = [module for _, module in unmatched_group_modules]
393+
ignore_blocks = ["transformer_blocks", "single_transformer_blocks", "temporal_transformer_blocks", "blocks"]
394+
unmatched_modules = [module for name, module in unmatched_group_modules if name not in ignore_blocks]
392 395
unmatched_group = ModuleGroup(
393 396
unmatched_modules,
394 397
offload_device,

0 commit comments

Comments
 (0)