@@ -120,7 +120,7 @@ def _process_output(self, output: Any) -> Any:
120120 """Process the output from the original component.
121121
122122 This method intercepts the output to create attention patterns
123- the same way as the old implementation.
123+ the same way as the old implementation and applies hook_out.
124124
125125 Args:
126126 output: Raw output from the original component
@@ -145,6 +145,17 @@ def _process_output(self, output: Any) -> Any:
145145
146146 # Apply the pattern to the output if needed
147147 output = self ._apply_pattern_to_output (output , pattern )
148+ else :
149+ # If no attention scores found, still apply hooks to the output
150+ if isinstance (output , tuple ):
151+ output = self ._process_tuple_output (output )
152+ elif isinstance (output , dict ):
153+ output = self ._process_dict_output (output )
154+ else :
155+ output = self ._process_single_output (output )
156+
157+ # Always apply hook_out to the main output
158+ output = self ._apply_hook_out_to_output (output )
148159
149160 return output
150161
@@ -219,12 +230,6 @@ def _apply_pattern_to_tuple_output(
219230 element = pattern
220231 processed_output .append (element )
221232
222- # Apply the main hook_out to the first element (hidden states) if it exists
223- if len (processed_output ) > 0 and processed_output [0 ] is not None :
224- processed_output [0 ] = self ._apply_hook_preserving_structure (
225- processed_output [0 ], self .hook_out
226- )
227-
228233 return tuple (processed_output )
229234
230235 def _apply_pattern_to_dict_output (
@@ -251,13 +256,6 @@ def _apply_pattern_to_dict_output(
251256 value = pattern
252257 processed_output [key ] = value
253258
254- # Apply hook_hidden_states and hook_out to the main output (usually hidden_states)
255- main_key = next ((k for k in output .keys () if "hidden" in k .lower ()), None )
256- if main_key and main_key in processed_output :
257- processed_output [main_key ] = self ._apply_hook_preserving_structure (
258- processed_output [main_key ], self .hook_out
259- )
260-
261259 return processed_output
262260
263261 def _apply_pattern_to_single_output (
@@ -276,7 +274,6 @@ def _apply_pattern_to_single_output(
276274 output = self ._apply_hook_preserving_structure (output , self .hook_hidden_states )
277275 # Apply the pattern to the output
278276 output = self ._apply_pattern_to_hidden_states (output , pattern )
279- output = self ._apply_hook_preserving_structure (output , self .hook_out )
280277 return output
281278
282279 def _apply_pattern_to_hidden_states (
@@ -325,12 +322,6 @@ def _process_tuple_output(self, output: Tuple[Any, ...]) -> Tuple[Any, ...]:
325322 element = self ._apply_hook_preserving_structure (element , self .hook_pattern )
326323 processed_output .append (element )
327324
328- # Apply the main hook_out to the first element (hidden states) if it exists
329- if len (processed_output ) > 0 and processed_output [0 ] is not None :
330- processed_output [0 ] = self ._apply_hook_preserving_structure (
331- processed_output [0 ], self .hook_out
332- )
333-
334325 return tuple (processed_output )
335326
336327 def _process_dict_output (self , output : Dict [str , Any ]) -> Dict [str , Any ]:
@@ -351,13 +342,6 @@ def _process_dict_output(self, output: Dict[str, Any]) -> Dict[str, Any]:
351342 value = self ._apply_hook_preserving_structure (value , self .hook_pattern )
352343 processed_output [key ] = value
353344
354- # Apply hook_hidden_states and hook_out to the main output (usually hidden_states)
355- main_key = next ((k for k in output .keys () if "hidden" in k .lower ()), None )
356- if main_key and main_key in processed_output :
357- processed_output [main_key ] = self ._apply_hook_preserving_structure (
358- processed_output [main_key ], self .hook_out
359- )
360-
361345 return processed_output
362346
363347 def _process_single_output (self , output : torch .Tensor ) -> torch .Tensor :
@@ -371,7 +355,6 @@ def _process_single_output(self, output: torch.Tensor) -> torch.Tensor:
371355 """
372356 # Apply hooks for single tensor output
373357 output = self ._apply_hook_preserving_structure (output , self .hook_hidden_states )
374- output = self ._apply_hook_preserving_structure (output , self .hook_out )
375358 return output
376359
377360 def _apply_hook_preserving_structure (self , element : Any , hook_fn ) -> Any :
@@ -395,6 +378,34 @@ def _apply_hook_preserving_structure(self, element: Any, hook_fn) -> Any:
395378 else :
396379 return element
397380
381+ def _apply_hook_out_to_output (self , output : Any ) -> Any :
382+ """Apply hook_out to the main output tensor.
383+
384+ Args:
385+ output: The output to process (can be tensor, tuple, or dict)
386+
387+ Returns:
388+ The output with hook_out applied to the main tensor
389+ """
390+ if isinstance (output , torch .Tensor ):
391+ return self .hook_out (output )
392+ elif isinstance (output , tuple ) and len (output ) > 0 :
393+ # Apply hook_out to the first element (typically hidden states)
394+ processed_tuple = list (output )
395+ if isinstance (output [0 ], torch .Tensor ):
396+ processed_tuple [0 ] = self .hook_out (output [0 ])
397+ return tuple (processed_tuple )
398+ elif isinstance (output , dict ):
399+ # Apply hook_out to the main hidden states in dictionary
400+ processed_dict = output .copy ()
401+ for key in ["last_hidden_state" , "hidden_states" ]:
402+ if key in processed_dict and isinstance (processed_dict [key ], torch .Tensor ):
403+ processed_dict [key ] = self .hook_out (processed_dict [key ])
404+ break # Only apply to the first found key
405+ return processed_dict
406+ else :
407+ return output
408+
398409 def forward (self , * args : Any , ** kwargs : Any ) -> Any :
399410 """Forward pass through the attention layer.
400411
0 commit comments