Skip to content

Commit 5ea3d8a

Browse files
committed
optimise
1 parent f30c55f commit 5ea3d8a

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,10 @@ def post_forward(self, module, output):
218218
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
219219

220220
for i in range(num_executed):
221-
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER)
221+
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
222222

223223
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
224-
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING)
224+
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
225225

226226
# Apply lazy prefetching by setting required attributes
227227
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
@@ -536,7 +536,10 @@ def _apply_lazy_group_offloading_hook(
536536
hook = GroupOffloadingHook(group, offload_on_init, next_group)
537537
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
538538
registry = HookRegistry.check_if_exists_or_initialize(module)
539-
registry.register_hook(hook, _GROUP_OFFLOADING)
539+
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
540+
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
541+
if registry.get_hook(_GROUP_OFFLOADING) is None:
542+
registry.register_hook(hook, _GROUP_OFFLOADING)
540543
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
541544

542545

src/diffusers/hooks/hooks.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import functools
16-
import gc
1716
from typing import Any, Dict, Optional, Tuple
1817

1918
import torch
@@ -187,8 +186,6 @@ def remove_hook(self, name: str, recurse: bool = True) -> None:
187186
if hasattr(module, "_diffusers_hook"):
188187
module._diffusers_hook.remove_hook(name, recurse=False)
189188

190-
gc.collect()
191-
192189
def reset_stateful_hooks(self, recurse: bool = True) -> None:
193190
for hook_name in reversed(self._hook_order):
194191
hook = self.hooks[hook_name]

0 commit comments

Comments
 (0)