Skip to content

Commit b850c75

Browse files
committed
update
1 parent 0bf0baf commit b850c75

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -448,18 +448,22 @@ def _apply_group_offloading_leaf_level(
448448
# of the module is called
449449
parameters = []
450450
buffers = []
451+
module_dict = dict(module.named_modules())
451452

452-
def gather_non_module_parameters_and_buffers(m: torch.nn.Module):
453-
if isinstance(m, _SUPPORTED_PYTORCH_LAYERS):
454-
return
455-
for parameter in m.parameters(recurse=False):
456-
parameters.append(parameter)
457-
for buffer in m.buffers(recurse=False):
458-
buffers.append(buffer)
459-
for submodule in m.children():
460-
gather_non_module_parameters_and_buffers(submodule)
453+
for name, parameter in module.named_parameters():
454+
atoms = name.split(".")
455+
parent_name = ".".join(atoms[:-1])
456+
if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS):
457+
continue
458+
parameters.append(parameter)
459+
460+
for name, buffer in module.named_buffers():
461+
atoms = name.split(".")
462+
parent_name = ".".join(atoms[:-1])
463+
if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS):
464+
continue
465+
buffers.append(buffer)
461466

462-
gather_non_module_parameters_and_buffers(module)
463467
unmatched_group = ModuleGroup(
464468
modules=[],
465469
offload_device=offload_device,

src/diffusers/hooks/hooks.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ class ModelHook:
3131

3232
_is_stateful = False
3333

34+
def __init__(self) -> None:
35+
self.fn_ref = None
36+
3437
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
3538
r"""
3639
Hook that is executed when a model is initialized.
@@ -103,6 +106,7 @@ def __init__(self) -> None:
103106
self.pre_forward = None
104107
self.post_forward = None
105108
self.old_forward = None
109+
self.is_overwritten_forward = False
106110

107111

108112
class HookRegistry:
@@ -119,40 +123,36 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
119123
if name in self.hooks.keys():
120124
logger.warning(f"Hook with name {name} already exists, replacing it.")
121125

122-
forward = self._module_ref.forward
123-
124-
fn_ref = FunctionReference()
125-
fn_ref.pre_forward = hook.pre_forward
126-
fn_ref.post_forward = hook.post_forward
127-
fn_ref.old_forward = forward
128-
129126
self._module_ref = hook.initialize_hook(self._module_ref)
130127

131-
def create_new_forward(function_reference: FunctionReference):
128+
def create_new_forward(function_reference: FunctionReference, forward):
132129
def new_forward(module, *args, **kwargs):
133130
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
134-
output = function_reference.old_forward(*args, **kwargs)
131+
output = forward(*args, **kwargs)
135132
return function_reference.post_forward(module, output)
136133

137134
return new_forward
138135

139-
# if hasattr(hook, "new_forward"):
140-
# fn_ref.old_forward = hook.new_forward
136+
forward = self._module_ref.forward
141137

142-
# def new_forward(module, *args, **kwargs):
143-
# args, kwargs = hook.pre_forward(module, *args, **kwargs)
144-
# output = rewritten_forward(module, *args, **kwargs)
145-
# return hook.post_forward(module, output)
146-
# else:
138+
fn_ref = FunctionReference()
139+
fn_ref.pre_forward = hook.pre_forward
140+
fn_ref.post_forward = hook.post_forward
141+
fn_ref.old_forward = forward
147142

148-
# def new_forward(module, *args, **kwargs):
149-
# args, kwargs = hook.pre_forward(module, *args, **kwargs)
150-
# output = forward(*args, **kwargs)
151-
# return hook.post_forward(module, output)
143+
if hasattr(hook, "new_forward"):
144+
new_forward = hook.new_forward
145+
fn_ref.is_overwritten_forward = True
146+
else:
147+
new_forward = forward
148+
fn_ref.is_overwritten_forward = False
152149

153-
new_forward = create_new_forward(fn_ref)
154-
self._module_ref.forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), forward)
150+
rewritten_forward = create_new_forward(fn_ref, new_forward)
151+
self._module_ref.forward = functools.update_wrapper(
152+
functools.partial(rewritten_forward, self._module_ref), forward
153+
)
155154

155+
hook.fn_ref = fn_ref
156156
self.hooks[name] = hook
157157
self._hook_order.append(name)
158158
self._fn_refs.append(fn_ref)
@@ -165,7 +165,6 @@ def remove_hook(self, name: str, recurse: bool = True) -> None:
165165
if name in self.hooks.keys():
166166
hook = self.hooks[name]
167167
index = self._hook_order.index(name)
168-
169168
fn_ref = self._fn_refs[index]
170169

171170
if index == num_hooks - 1:

0 commit comments

Comments
 (0)