Skip to content

Commit 4996dfd

Browse files
committed
fix
1 parent 3c498ef commit 4996dfd

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

src/diffusers/models/hooks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def reset_state(self, module):
114114
for hook in self.hooks:
115115
if hook._is_stateful:
116116
hook.reset_state(module)
117+
return module
117118

118119

119120
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module:

src/diffusers/pipelines/fastercache_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
526526

527527
state.iteration += 1
528528
output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states
529-
return output
529+
return module._diffusers_hook.post_forward(module, output)
530530

531531
def reset_state(self, module: nn.Module) -> nn.Module:
532532
module._fastercache_state.reset()

0 commit comments

Comments
 (0)