Commit a4bfa45

Author: toilaluan
Message: init taylor_seer cache
1 parent d8e4805 commit a4bfa45

File tree: 1 file changed

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Experimental hook for TaylorSeer cache
# Supports Flux only for now

import math
from dataclasses import dataclass
from typing import Callable, Optional

import torch

from ..hooks import HookRegistry
from ..models.attention import Attention, AttentionModuleMixin
from ._common import _ATTENTION_CLASSES
from .hooks import ModelHook

_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

@dataclass
class TaylorSeerCacheConfig:
    fresh_threshold: int = 5  # interleave cache and compute: out of every `fresh_threshold` steps, one is fully computed and the rest are predicted from the cache
    max_order: int = 1  # order of the Taylor series expansion
    current_timestep_callback: Optional[Callable[[], int]] = None  # returns the current denoising step

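# Illustrative schedule (derived from the counter logic below): with fresh_threshold=5,
# step 1 runs the full attention and refreshes the Taylor factors, steps 2-5 are served
# from the cache, step 6 recomputes, and so on.
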
class TaylorSeerState:
    def __init__(self):
        # Start at 0 so the very first forward pass is a full compute (the cache is empty).
        self.predict_counter: int = 0
        self.last_step: int = 1000
        # taylor_factors[i] holds the i-th order finite-difference factor of the features.
        self.taylor_factors: dict[int, torch.Tensor] = {}

    def reset(self):
        self.predict_counter = 0
        self.last_step = 1000
        self.taylor_factors = {}

    def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int):
        N = abs(current_step - self.last_step)  # number of steps since the last full compute
        # Initialize the zeroth-order factor with the freshly computed features
        new_taylor_factors = {0: features}
        # Build higher-order factors from finite differences against the previous ones
        for i in range(max_order):
            if (self.taylor_factors.get(i) is not None) and current_step > 1:
                new_taylor_factors[i + 1] = (self.taylor_factors[i] - new_taylor_factors[i]) / N
            else:
                break
        self.taylor_factors = new_taylor_factors
        self.last_step = current_step
        self.predict_counter = (self.predict_counter + 1) % refresh_threshold

    def predict(self, current_step: int, refresh_threshold: int):
        # Evaluate the Taylor expansion at an offset of k steps from the last full compute
        k = current_step - self.last_step
        device = self.taylor_factors[0].device
        output = torch.zeros_like(self.taylor_factors[0], device=device)
        for i in range(len(self.taylor_factors)):
            output += self.taylor_factors[i] * (k**i) / math.factorial(i)
        self.predict_counter = (self.predict_counter + 1) % refresh_threshold
        return output

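    # Worked example (illustrative; assumes max_order=1 and descending timesteps):
    # a full compute at step 900 stores factors = {0: f900}; the next full compute at
    # step 880 gives N = 20 and factors = {0: f880, 1: (f900 - f880) / 20}.
    # predict(875) then uses k = 875 - 880 = -5 and returns
    # f880 - 5 * (f900 - f880) / 20 = f880 + 0.25 * (f880 - f900),
    # i.e. a linear extrapolation of the feature trajectory.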

class TaylorSeerAttentionCacheHook(ModelHook):
    _is_stateful = True

    def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int]):
        super().__init__()
        self.fresh_threshold = fresh_threshold
        self.max_order = max_order
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        # Separate states for the image stream, text stream, and IP-adapter outputs
        self.img_state = TaylorSeerState()
        self.txt_state = TaylorSeerState()
        self.ip_state = TaylorSeerState()
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        current_step = self.current_timestep_callback()
        assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook"
        should_predict = self.img_state.predict_counter > 0

        if not should_predict:
            # Full compute step: run the original attention and refresh the Taylor factors
            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
            if len(attention_outputs) == 2:
                attn_output, context_attn_output = attention_outputs
                self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold)
                self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold)
            elif len(attention_outputs) == 3:
                attn_output, context_attn_output, ip_attn_output = attention_outputs
                self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold)
                self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold)
                self.ip_state.update(ip_attn_output, current_step, self.max_order, self.fresh_threshold)
        else:
            # Cached step: extrapolate each stream's output from its Taylor factors.
            # Only return an IP-adapter output if one was cached on a compute step.
            attn_output = self.img_state.predict(current_step, self.fresh_threshold)
            context_attn_output = self.txt_state.predict(current_step, self.fresh_threshold)
            if self.ip_state.taylor_factors:
                ip_attn_output = self.ip_state.predict(current_step, self.fresh_threshold)
                attention_outputs = (attn_output, context_attn_output, ip_attn_output)
            else:
                attention_outputs = (attn_output, context_attn_output)
        return attention_outputs

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.img_state.reset()
        self.txt_state.reset()
        self.ip_state.reset()
        return module


def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
    for name, submodule in module.named_modules():
        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            # TaylorSeer caching is implemented specifically for Diffusers' Attention
            # classes. That does not mean it cannot be applied to other layers: for
            # custom layers, users can extend this functionality and implement their
            # own logic similar to `_apply_taylorseer_cache_on_attention_class`.
            continue
        _apply_taylorseer_cache_on_attention_class(name, submodule, config)


def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig):
    _apply_taylorseer_cache_hook(module, config)


def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig):
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback)
    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
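
For context, a minimal usage sketch, not part of this commit: it assumes the module lands under diffusers.hooks, and that the Flux pipeline exposes its current denoising step via an attribute, mirroring the callback pattern of the existing Diffusers cache hooks. Import paths and the timestep attribute are assumptions.

import torch
from diffusers import FluxPipeline
# Hypothetical import path for the file added in this commit
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

config = TaylorSeerCacheConfig(
    fresh_threshold=5,  # one full attention compute per 5 denoising steps
    max_order=1,        # linear extrapolation
    current_timestep_callback=lambda: pipe._current_timestep,  # assumed pipeline attribute
)
apply_taylorseer_cache(pipe.transformer, config)

image = pipe("a photo of a cat", num_inference_steps=28).images[0]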
