# Experimental hook for TaylorSeer cache
# Supports Flux only for now

import math
from dataclasses import dataclass
from typing import Callable

import torch

from ..hooks import HookRegistry
from ..models.attention import Attention, AttentionModuleMixin
from ._common import _ATTENTION_CLASSES
from .hooks import ModelHook

_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

@dataclass
class TaylorSeerCacheConfig:
    fresh_threshold: int = 5  # one full attention compute every `fresh_threshold` steps; the remaining steps reuse the Taylor-series prediction
    max_order: int = 1  # order of the Taylor series expansion
    current_timestep_callback: Callable[[], int] = None  # callback returning the current denoising step value

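# Illustration of the caching schedule implied by `fresh_threshold` (a sketch, not executed code): with
# fresh_threshold=5, each state's counter cycles 0, 1, 2, 3, 4, 0, ... across denoising steps, so one step
# in every five runs the real attention forward pass (counter == 0) and the other four reuse the
# Taylor-series prediction.
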
class TaylorSeerState:
    def __init__(self):
        # the counter starts at 0 so that the first step runs a full compute before any prediction is attempted
        self.predict_counter: int = 0
        self.last_step: int = 1000
        self.taylor_factors: dict[int, torch.Tensor] = {}

    def reset(self):
        self.predict_counter = 0
        self.last_step = 1000
        self.taylor_factors = {}

    def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int):
        # number of steps between the previous and the current full compute, used as the finite-difference denominator
        N = abs(current_step - self.last_step)
        # the zeroth-order factor is the freshly computed features
        new_taylor_factors = {0: features}
        for i in range(max_order):
            # higher-order factors are finite differences against the previously cached factors
            if (self.taylor_factors.get(i) is not None) and current_step > 1:
                new_taylor_factors[i + 1] = (self.taylor_factors[i] - new_taylor_factors[i]) / N
            else:
                break
        self.taylor_factors = new_taylor_factors
        self.last_step = current_step
        self.predict_counter = (self.predict_counter + 1) % refresh_threshold

    def predict(self, current_step: int, refresh_threshold: int):
        # evaluate the truncated Taylor series at offset k from the last fully computed step
        k = current_step - self.last_step
        output = torch.zeros_like(self.taylor_factors[0])
        for i in range(len(self.taylor_factors)):
            output += self.taylor_factors[i] * (k**i) / math.factorial(i)
        self.predict_counter = (self.predict_counter + 1) % refresh_threshold
        return output

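# Sketch of the extrapolation performed by `predict` (illustrative comment, not executed code): after a
# full compute at step `last_step`, the cached factors are
#     taylor_factors[0] = F(last_step)
#     taylor_factors[1] = (F(prev_last_step) - F(last_step)) / |prev_last_step - last_step|
# and a prediction at step t returns the truncated Taylor series
#     sum_i taylor_factors[i] * (t - last_step) ** i / factorial(i)
# which, for max_order=1, is a linear extrapolation of the attention output between full compute steps.
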
class TaylorSeerAttentionCacheHook(ModelHook):
    _is_stateful = True

    def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int]):
        super().__init__()
        self.fresh_threshold = fresh_threshold
        self.max_order = max_order
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        # separate Taylor states for the image, text/context, and IP-adapter attention outputs
        self.img_state = TaylorSeerState()
        self.txt_state = TaylorSeerState()
        self.ip_state = TaylorSeerState()
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        current_step = self.current_timestep_callback()
        assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook"
        should_predict = self.img_state.predict_counter > 0

        if not should_predict:
            # full compute step: run the original attention and refresh the Taylor factors
            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
            if len(attention_outputs) == 2:
                attn_output, context_attn_output = attention_outputs
                self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold)
                self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold)
            elif len(attention_outputs) == 3:
                attn_output, context_attn_output, ip_attn_output = attention_outputs
                self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold)
                self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold)
                self.ip_state.update(ip_attn_output, current_step, self.max_order, self.fresh_threshold)
        else:
            # cached step: extrapolate the attention outputs instead of running the original forward
            attn_output = self.img_state.predict(current_step, self.fresh_threshold)
            context_attn_output = self.txt_state.predict(current_step, self.fresh_threshold)
            if self.ip_state.taylor_factors:
                # only present when the wrapped attention also returns an IP-adapter output
                ip_attn_output = self.ip_state.predict(current_step, self.fresh_threshold)
                attention_outputs = (attn_output, context_attn_output, ip_attn_output)
            else:
                attention_outputs = (attn_output, context_attn_output)
        return attention_outputs

    def reset_state(self, module: torch.nn.Module):
        self.img_state.reset()
        self.txt_state.reset()
        self.ip_state.reset()
        return module

def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
    for name, submodule in module.named_modules():
        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            # The TaylorSeer cache hook is implemented specifically for Diffusers' Attention classes. This does
            # not mean it cannot be applied to other layers; for custom layers, users can extend this
            # functionality and implement their own caching logic similar to
            # `_apply_taylorseer_cache_on_attention_class`.
            continue
        _apply_taylorseer_cache_on_attention_class(name, submodule, config)


def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig):
    _apply_taylorseer_cache_hook(module, config)


def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig):
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback)
    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
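

# Minimal usage sketch (illustrative only; `transformer` and the step tracking below are placeholders, not a
# confirmed API). The caller is expected to track the current denoising step and expose it through
# `current_timestep_callback`, for example:
#
#     step_tracker = {"step": None}
#     config = TaylorSeerCacheConfig(
#         fresh_threshold=5,
#         max_order=1,
#         current_timestep_callback=lambda: step_tracker["step"],
#     )
#     apply_taylorseer_cache(transformer, config)
#     # ... update step_tracker["step"] with the current timestep inside the denoising loop.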