Skip to content

Commit 2ff1716

Browse files
committed
update
1 parent c3d22f7 commit 2ff1716

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
if is_torch_available():
5-
from .enhance_a_video import EnhanceAVideoConfig, apply_enhance_a_video
5+
from .enhance_a_video import EnhanceAVideoConfig, apply_enhance_a_video, remove_enhance_a_video
66
from .group_offloading import apply_group_offloading
77
from .hooks import HookRegistry, ModelHook
88
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

src/diffusers/hooks/enhance_a_video.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929

3030
_ENHANCE_A_VIDEO = "enhance_a_video"
31-
_ENHANCE_A_VIDEO_SDPA = "enhance_a_video_sdpa"
3231

3332

3433
class _AttentionType(Enum):
@@ -188,4 +187,13 @@ def apply_enhance_a_video(module: torch.nn.Module, config: EnhanceAVideoConfig)
188187
num_frames_callback=config.num_frames_callback,
189188
_attention_type=config._attention_type,
190189
)
191-
hook_registry.register_hook(hook, _ENHANCE_A_VIDEO_SDPA)
190+
hook_registry.register_hook(hook, _ENHANCE_A_VIDEO)
191+
192+
193+
def remove_enhance_a_video(module: torch.nn.Module) -> None:
194+
for name, submodule in module.named_modules():
195+
if not hasattr(submodule, "_diffusers_hook"):
196+
continue
197+
hook_registry = submodule._diffusers_hook
198+
hook_registry.remove_hook(_ENHANCE_A_VIDEO, recurse=False)
199+
logger.debug(f"Removed Enhance-A-Video from layer '{name}'")

0 commit comments

Comments
 (0)