Skip to content

Commit 46619ea

Browse files
committed
update
1 parent c76e1cc commit 46619ea

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/diffusers/hooks/first_block_cache.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ def initialize_hook(self, module):
7777

7878
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
7979
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
80-
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
80+
81+
if isinstance(outputs_if_skipped, tuple):
82+
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
83+
else:
84+
original_hs = outputs_if_skipped
8185

8286
output = self.fn_ref.original_forward(*args, **kwargs)
8387
is_output_tuple = isinstance(output, tuple)
@@ -200,14 +204,14 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf
200204
head_block_name, head_block = remaining_blocks.pop(0)
201205
tail_block_name, tail_block = remaining_blocks.pop(-1)
202206

203-
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
207+
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
204208
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
205209

206210
for name, block in remaining_blocks:
207-
logger.debug(f"Apply FBCBlockHook to '{name}'")
211+
logger.debug(f"Applying FBCBlockHook to '{name}'")
208212
apply_fbc_block_hook(block, shared_state)
209213

210-
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
214+
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
211215
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
212216

213217

0 commit comments

Comments
 (0)