11import torch
22from dataclasses import dataclass
3- from typing import Callable , Optional , List , Dict , Tuple
3+ from typing import Optional , List , Dict , Tuple
44from .hooks import ModelHook
55import math
66from ..models .attention import Attention
1111from ..hooks import HookRegistry
1212from ..utils import logging
1313import re
14- from collections import defaultdict
1514
1615
17- logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
18-
16+ logger = logging .get_logger (__name__ )
1917
2018_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
2119
@@ -70,6 +68,7 @@ def __repr__(self) -> str:
7068 def get_identifiers_template (self ) -> Dict [str , Dict [str , List [str ]]]:
7169 return _CACHE_TEMPLATES
7270
71+
7372class TaylorSeerOutputState :
7473 """
7574 Manages the state for Taylor series-based prediction of a single attention output.
@@ -219,9 +218,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
219218 module_dtype = attention_outputs [0 ].dtype
220219 self .num_outputs = len (attention_outputs )
221220 self .states = [
222- TaylorSeerOutputState (
223- self .module_name , self .taylor_factors_dtype , module_dtype , is_skip = self .is_skip
224- )
221+ TaylorSeerOutputState (self .module_name , self .taylor_factors_dtype , module_dtype , is_skip = self .is_skip )
225222 for _ in range (self .num_outputs )
226223 ]
227224 for i , features in enumerate (attention_outputs ):
@@ -249,7 +246,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
249246 attention_outputs = list (attention_outputs )
250247 is_first_update = self .step_counter == 0 # Only True for the very first step
251248 for i , features in enumerate (attention_outputs ):
252- self .states [i ].update (features , self .step_counter , self .max_order , self .predict_steps , is_first_update )
249+ self .states [i ].update (
250+ features , self .step_counter , self .max_order , self .predict_steps , is_first_update
251+ )
253252 return attention_outputs [0 ] if self .num_outputs == 1 else tuple (attention_outputs )
254253 else :
255254 # Predict using Taylor series
@@ -330,4 +329,4 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee
330329 is_skip = is_skip ,
331330 )
332331
333- registry .register_hook (hook , _TAYLORSEER_ATTENTION_CACHE_HOOK )
332+ registry .register_hook (hook , _TAYLORSEER_ATTENTION_CACHE_HOOK )
0 commit comments