
Commit 6a9a3e5

non_blocking; handle parameters and buffers
1 parent 2783669 commit 6a9a3e5

File tree

2 files changed (+62, -38 lines)

src/diffusers/hooks/group_offloading.py

Lines changed: 60 additions & 28 deletions
@@ -20,14 +20,17 @@
 from .hooks import HookRegistry, ModelHook
 
 
-_TRANSFORMER_STACK_IDENTIFIERS = [
+_COMMON_STACK_IDENTIFIERS = {
     "transformer_blocks",
     "single_transformer_blocks",
     "temporal_transformer_blocks",
     "transformer_layers",
     "layers",
     "blocks",
-]
+    "down_blocks",
+    "up_blocks",
+    "mid_blocks",
+}
 
 
 class ModuleGroup:
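
The identifier collection becomes a set and now also covers UNet-style stacks (down_blocks, up_blocks, mid_blocks) alongside the transformer ones. As a rough illustration of how such identifiers are matched against a model (ToyModel below is a hypothetical stand-in, not a diffusers class), the lookup amounts to finding a torch.nn.ModuleList attribute with one of these names:

```python
import torch

_COMMON_STACK_IDENTIFIERS = {"transformer_blocks", "down_blocks", "up_blocks", "mid_blocks"}

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A stack of repeated blocks, as found in transformer/UNet models.
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
        # A non-stack, top-level layer.
        self.proj_out = torch.nn.Linear(8, 8)

model = ToyModel()
for identifier in _COMMON_STACK_IDENTIFIERS:
    if hasattr(model, identifier) and isinstance(getattr(model, identifier), torch.nn.ModuleList):
        print(identifier, "->", len(getattr(model, identifier)), "blocks")  # transformer_blocks -> 4 blocks
```
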
@@ -62,25 +65,16 @@ class GroupOffloadingHook(ModelHook):
     encounter such an error.
     """
 
-    def __init__(self, group: ModuleGroup, offload_on_init: bool = True) -> None:
+    def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None:
         self.group = group
         self.offload_on_init = offload_on_init
+        self.non_blocking = non_blocking
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.offload_on_init:
             self.offload_(module)
         return module
 
-    def onload_(self, module: torch.nn.Module) -> None:
-        if self.group.onload_leader is not None and self.group.onload_leader == module:
-            for group_module in self.group.modules:
-                group_module.to(self.group.onload_device)
-
-    def offload_(self, module: torch.nn.Module) -> None:
-        if self.group.offload_leader == module:
-            for group_module in self.group.modules:
-                group_module.to(self.group.offload_device)
-
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         if self.group.onload_leader is None:
             self.group.onload_leader = module
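
For context, pre_forward elects the onload leader lazily: the group starts with onload_leader=None and the first module of the group to run a forward pass becomes the leader, so the whole group is moved to the onload device once rather than per block. A minimal sketch of that pattern (TinyGroup and the free-standing pre_forward below are simplifications, not the diffusers classes):

```python
import torch

class TinyGroup:
    """Illustrative stand-in for ModuleGroup: a set of modules that move between devices together."""

    def __init__(self, modules, offload_device, onload_device):
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.onload_leader = None          # elected lazily: first module of the group to run forward
        self.offload_leader = modules[-1]  # the last module in the group triggers the offload

def pre_forward(module, group):
    # Mirrors the hook's pre_forward: the first module to reach forward becomes the
    # onload leader and moves the whole group onto the onload device.
    if group.onload_leader is None:
        group.onload_leader = module
    if group.onload_leader is module:
        for m in group.modules:
            m.to(group.onload_device)

group = TinyGroup([torch.nn.Linear(4, 4) for _ in range(3)], torch.device("cpu"), torch.device("cpu"))
pre_forward(group.modules[0], group)
print(group.onload_leader is group.modules[0])  # True; later blocks reuse the already-onloaded group
```
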
@@ -91,6 +85,19 @@ def post_forward(self, module: torch.nn.Module, output):
         self.offload_(module)
         return output
 
+    def onload_(self, module: torch.nn.Module) -> None:
+        if self.group.onload_leader == module:
+            for group_module in self.group.modules:
+                group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
+
+    def offload_(self, module: torch.nn.Module) -> None:
+        if self.group.offload_leader == module:
+            for group_module in self.group.modules:
+                group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
+            # TODO: do we need to sync here because of GPU->CPU transfer?
+            if self.non_blocking and self.group.offload_device.type == "cpu":
+                torch.cpu.synchronize()
+
 
 def apply_group_offloading(
     module: torch.nn.Module,
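
The TODO in offload_ is about ordering asynchronous device-to-host copies: with non_blocking=True, .to() can return before the data has actually landed in CPU memory, so some synchronization point is generally needed before the CPU tensor is read. The commit guards this with torch.cpu.synchronize(); the small demonstration below instead uses torch.cuda.synchronize(), which is the more common way to order a CUDA-to-CPU copy, and is only a hedged sketch of the hazard, not the library's implementation:

```python
import torch

if torch.cuda.is_available():
    gpu_tensor = torch.randn(1024, 1024, device="cuda")

    # The device-to-host copy is queued on the current CUDA stream and may not
    # have finished when .to() returns.
    cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)

    # Synchronizing the CUDA device (or the issuing stream) guarantees the copy
    # has completed before the CPU tensor is read.
    torch.cuda.synchronize()
    print(cpu_tensor.sum())
```
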
@@ -99,14 +106,17 @@ def apply_group_offloading(
     offload_device: torch.device = torch.device("cpu"),
     onload_device: torch.device = torch.device("cuda"),
     force_offload: bool = True,
+    non_blocking: bool = False,
 ) -> None:
     if offload_group_patterns == "diffusers_block":
+        if num_blocks_per_group is None:
+            raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.")
         _apply_group_offloading_diffusers_block(
-            module, num_blocks_per_group, offload_device, onload_device, force_offload
+            module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking
         )
     else:
         _apply_group_offloading_group_patterns(
-            module, offload_group_patterns, offload_device, onload_device, force_offload
+            module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking
         )

@@ -116,26 +126,47 @@ def _apply_group_offloading_diffusers_block(
     offload_device: torch.device,
     onload_device: torch.device,
     force_offload: bool,
+    non_blocking: bool,
 ) -> None:
-    if num_blocks_per_group is None:
-        raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.")
-
-    for transformer_stack_identifier in _TRANSFORMER_STACK_IDENTIFIERS:
-        if not hasattr(module, transformer_stack_identifier) or not isinstance(
-            getattr(module, transformer_stack_identifier), torch.nn.ModuleList
+    # Handle device offloading/onloading for unet/transformer stack modules
+    for stack_identifier in _COMMON_STACK_IDENTIFIERS:
+        if not hasattr(module, stack_identifier) or not isinstance(
+            getattr(module, stack_identifier), torch.nn.ModuleList
         ):
             continue
 
-        transformer_stack = getattr(module, transformer_stack_identifier)
-        num_blocks = len(transformer_stack)
+        stack = getattr(module, stack_identifier)
+        num_blocks = len(stack)
 
         for i in range(0, num_blocks, num_blocks_per_group):
-            blocks = transformer_stack[i : i + num_blocks_per_group]
+            blocks = stack[i : i + num_blocks_per_group]
             group = ModuleGroup(
                 blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
             )
             should_offload = force_offload or i > 0
-            _apply_group_offloading(group, should_offload)
+            _apply_group_offloading(group, should_offload, non_blocking)
+
+    # Handle device offloading/onloading for non-stack modules
+    for name, submodule in module.named_modules():
+        name_split = name.split(".")
+        if not isinstance(submodule, torch.nn.Module) or name == "" or len(name_split) > 1:
+            # We only want the layers that are top-level in the module (encompass all the submodules)
+            # for enabling offloading.
+            continue
+        layer_name = name_split[0]
+        print(layer_name)
+        if layer_name in _COMMON_STACK_IDENTIFIERS:
+            continue
+        group = ModuleGroup(
+            [submodule], offload_device, onload_device, offload_leader=submodule, onload_leader=submodule
+        )
+        _apply_group_offloading(group, force_offload, non_blocking)
+
+    # Always keep parameters and buffers on onload_device
+    for name, param in module.named_parameters(recurse=False):
+        param.data = param.data.to(onload_device)
+    for name, buffer in module.named_buffers(recurse=False):
+        buffer.data = buffer.data.to(onload_device)
 
 
 def _apply_group_offloading_group_patterns(
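
This hunk does three things: it splits each recognized block stack into chunks of num_blocks_per_group, it wraps every remaining top-level child (depth-one names, i.e. names without a ".") in its own single-module group, and it keeps the module's own root-level parameters and buffers on the onload device. A small illustration of the chunking and the top-level filtering, using a toy model rather than a diffusers class:

```python
import torch

class ToyBlockStack(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(5)])
        self.norm_out = torch.nn.LayerNorm(4)

model = ToyBlockStack()
num_blocks_per_group = 2

# Chunk the stack the same way the diffusers_block path does: blocks [0, 1], [2, 3], [4].
for i in range(0, len(model.blocks), num_blocks_per_group):
    chunk = model.blocks[i : i + num_blocks_per_group]
    print("group:", [type(m).__name__ for m in chunk])

# Keep only depth-one children; nested names such as "blocks.0" contain a dot and are skipped.
top_level = [name for name, _ in model.named_modules() if name and len(name.split(".")) == 1]
print(top_level)  # ['blocks', 'norm_out']
```
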
@@ -144,6 +175,7 @@ def _apply_group_offloading_group_patterns(
     offload_device: torch.device,
     onload_device: torch.device,
     force_offload: bool,
+    non_blocking: bool,
 ) -> None:
     per_group_modules = []
     for i, offload_group_pattern in enumerate(offload_group_patterns):

@@ -174,11 +206,11 @@ def _apply_group_offloading_group_patterns(
     for group in per_group_modules:
         # TODO: handle offload leader correctly
         group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1])
-        _apply_group_offloading(group, force_offload)
+        _apply_group_offloading(group, force_offload, non_blocking)
 
 
-def _apply_group_offloading(group: ModuleGroup, offload_on_init) -> None:
+def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None:
     for module in group.modules:
-        hook = GroupOffloadingHook(group, offload_on_init=offload_on_init)
+        hook = GroupOffloadingHook(group, offload_on_init, non_blocking)
         registry = HookRegistry.check_if_exists_or_initialize(module)
         registry.register_hook(hook, "group_offloading")
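
Taken together, the entry point now threads non_blocking all the way down to the per-module hooks, and validates num_blocks_per_group up front for the "diffusers_block" path. A hypothetical usage sketch, assuming the function is importable from the path shown in this diff; the ToyTransformer model and the keyword spelling of arguments not visible in the hunks are illustrative only:

```python
import torch

# Assumed import path mirroring src/diffusers/hooks/group_offloading.py at this commit.
from diffusers.hooks.group_offloading import apply_group_offloading

class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(6)])

if torch.cuda.is_available():
    model = ToyTransformer()
    apply_group_offloading(
        model,
        num_blocks_per_group=2,                    # every 2 consecutive blocks share one on/offload group
        offload_group_patterns="diffusers_block",
        offload_device=torch.device("cpu"),
        onload_device=torch.device("cuda"),
        force_offload=True,                        # offload every group at init, first group included
        non_blocking=True,                         # use asynchronous .to() transfers
    )
```
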

src/diffusers/hooks/hooks.py

Lines changed: 2 additions & 10 deletions
@@ -131,16 +131,8 @@ def new_forward(module, *args, **kwargs):
             output = old_forward(*args, **kwargs)
             return hook.post_forward(module, output)
 
-        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
-        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
-        if "GraphModuleImpl" in str(type(self._module_ref)):
-            self._module_ref.__class__.forward = functools.update_wrapper(
-                functools.partial(new_forward, self._module_ref), old_forward
-            )
-        else:
-            self._module_ref.forward = functools.update_wrapper(
-                functools.partial(new_forward, self._module_ref), old_forward
-            )
+        new_forward = functools.update_wrapper(new_forward, old_forward)
+        self._module_ref.forward = new_forward.__get__(self._module_ref)
 
         self.hooks[name] = hook
         self._hook_order.append(name)
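
The GraphModuleImpl special case is dropped in favour of binding new_forward directly to the module instance via the descriptor protocol: new_forward.__get__(module) produces a bound method, so the module is passed as the first argument just as functools.partial did, while functools.update_wrapper keeps the original forward's metadata. A minimal standalone demonstration of that binding pattern, with the hook calls replaced by prints:

```python
import functools
import torch

module = torch.nn.Linear(4, 4)
old_forward = module.forward

def new_forward(module, *args, **kwargs):
    # Stand-in for the hook's pre_forward / post_forward calls around the original forward.
    print("pre_forward on", type(module).__name__)
    output = old_forward(*args, **kwargs)
    print("post_forward")
    return output

# Copy metadata (__name__, __doc__, ...) from the wrapped forward, then bind to the instance.
new_forward = functools.update_wrapper(new_forward, old_forward)
module.forward = new_forward.__get__(module)

out = module.forward(torch.randn(2, 4))  # `module` is passed implicitly, like a normal bound method
print(out.shape)  # torch.Size([2, 4])
```
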
