Skip to content

Commit 5535fd6

Browse files
committed
fix remove hook behaviour
1 parent 3d269ad commit 5535fd6

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

src/diffusers/hooks/hooks.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -177,28 +177,25 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
177177
return self.hooks.get(name, None)
178178

179179
def remove_hook(self, name: str, recurse: bool = True) -> None:
180-
if name not in self.hooks.keys():
181-
logger.warning(f"hook: {name} was not found in HookRegistry")
182-
return
183-
184-
num_hooks = len(self._hook_order)
185-
hook = self.hooks[name]
186-
index = self._hook_order.index(name)
187-
fn_ref = self._fn_refs[index]
188-
189-
old_forward = fn_ref.forward
190-
if fn_ref.original_forward is not None:
191-
old_forward = fn_ref.original_forward
192-
193-
if index == num_hooks - 1:
194-
self._module_ref.forward = old_forward
195-
else:
196-
self._fn_refs[index + 1].forward = old_forward
197-
198-
self._module_ref = hook.deinitalize_hook(self._module_ref)
199-
del self.hooks[name]
200-
self._hook_order.pop(index)
201-
self._fn_refs.pop(index)
180+
if name in self.hooks.keys():
181+
num_hooks = len(self._hook_order)
182+
hook = self.hooks[name]
183+
index = self._hook_order.index(name)
184+
fn_ref = self._fn_refs[index]
185+
186+
old_forward = fn_ref.forward
187+
if fn_ref.original_forward is not None:
188+
old_forward = fn_ref.original_forward
189+
190+
if index == num_hooks - 1:
191+
self._module_ref.forward = old_forward
192+
else:
193+
self._fn_refs[index + 1].forward = old_forward
194+
195+
self._module_ref = hook.deinitalize_hook(self._module_ref)
196+
del self.hooks[name]
197+
self._hook_order.pop(index)
198+
self._fn_refs.pop(index)
202199

203200
if recurse:
204201
for module_name, module in self._module_ref.named_modules():

0 commit comments

Comments
 (0)