@@ -77,7 +77,11 @@ def initialize_hook(self, module):
7777
7878 def new_forward(self, module: torch.nn.Module, *args, **kwargs):
7979 outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
80- original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
80+
81+ if isinstance(outputs_if_skipped, tuple):
82+ original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
83+ else:
84+ original_hs = outputs_if_skipped
8185
8286 output = self.fn_ref.original_forward(*args, **kwargs)
8387 is_output_tuple = isinstance(output, tuple)
@@ -200,14 +204,14 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf
200204 head_block_name, head_block = remaining_blocks.pop(0)
201205 tail_block_name, tail_block = remaining_blocks.pop(-1)
202206
203- logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
207+ logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
204208 apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
205209
206210 for name, block in remaining_blocks:
207- logger.debug(f"Apply FBCBlockHook to '{name}'")
211+ logger.debug(f"Applying FBCBlockHook to '{name}'")
208212 apply_fbc_block_hook(block, shared_state)
209213
210- logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
214+ logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
211215 apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
212216
213217
0 commit comments