
Commit 8f495b6

Author: toilaluan
make compatible with any tuple size returned
1 parent: fe20f97

File tree

1 file changed: 26 additions (+), 29 deletions (-)


src/diffusers/hooks/taylorseer_cache.py

Lines changed: 26 additions & 29 deletions
@@ -23,12 +23,12 @@ class TaylorSeerCacheConfig:
 
 class TaylorSeerState:
     def __init__(self):
-        self.predict_counter: int = 1
+        self.predict_counter: int = 0
         self.last_step: int = 1000
         self.taylor_factors: dict[int, torch.Tensor] = {}
 
     def reset(self):
-        self.predict_counter = 1
+        self.predict_counter = 0
         self.last_step = 1000
         self.taylor_factors = {}
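Note: together with the update()/predict() changes in the next hunk, the counter now counts down from refresh_threshold instead of cycling with a modulo: a full attention computation resets it, each Taylor prediction decrements it, and a new full computation is triggered once it reaches 0. A minimal sketch of the resulting schedule, assuming an illustrative refresh_threshold of 3 (not a value taken from the commit):

# Sketch of the refresh cadence implied by the countdown counter.
# "compute" = run the real attention forward, "predict" = Taylor extrapolation.
refresh_threshold = 3  # illustrative value
counter = 0            # new initial value from this hunk

schedule = []
for step in range(10):
    if counter > 0:
        schedule.append((step, "predict"))
        counter -= 1                     # mirrors predict()
    else:
        schedule.append((step, "compute"))
        counter = refresh_threshold      # mirrors update()

print(schedule)
# -> one full computation followed by refresh_threshold predicted steps, repeated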

@@ -43,15 +43,15 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, refr
                 break
         self.taylor_factors = new_taylor_factors
         self.last_step = current_step
-        self.predict_counter = (self.predict_counter + 1) % refresh_threshold
+        self.predict_counter = refresh_threshold
 
-    def predict(self, current_step: int, refresh_threshold: int):
+    def predict(self, current_step: int):
         k = current_step - self.last_step
         device = self.taylor_factors[0].device
         output = torch.zeros_like(self.taylor_factors[0], device=device)
         for i in range(len(self.taylor_factors)):
             output += self.taylor_factors[i] * (k ** i) / math.factorial(i)
-        self.predict_counter = (self.predict_counter + 1) % refresh_threshold
+        self.predict_counter -= 1
         return output
 
 class TaylorSeerAttentionCacheHook(ModelHook):
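Note: predict() evaluates a truncated Taylor expansion around the last refreshed step, output = sum_i taylor_factors[i] * k**i / i! with k = current_step - last_step. A self-contained sketch of that extrapolation on a single cached value, with made-up factor values for illustration:

import math
import torch

# Cached "Taylor factors" for one feature: value, first and second difference.
# Illustrative numbers only; in the hook these are tensors built by update().
taylor_factors = {0: torch.tensor(2.0), 1: torch.tensor(0.5), 2: torch.tensor(-0.1)}

def predict(current_step: int, last_step: int) -> torch.Tensor:
    k = current_step - last_step
    output = torch.zeros_like(taylor_factors[0])
    for i in range(len(taylor_factors)):
        output += taylor_factors[i] * (k ** i) / math.factorial(i)
    return output

print(predict(current_step=7, last_step=4))  # 2.0 + 0.5*3 - 0.1*9/2 = 3.05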
@@ -64,47 +64,44 @@ def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callba
         self.current_timestep_callback = current_timestep_callback
 
     def initialize_hook(self, module):
-        self.img_state = TaylorSeerState()
-        self.txt_state = TaylorSeerState()
-        self.ip_state = TaylorSeerState()
+        self.states = None
+        self.num_outputs = None
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         current_step = self.current_timestep_callback()
         assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook"
-        should_predict = self.img_state.predict_counter > 0
+
+        if self.states is None:
+            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
+            self.num_outputs = len(attention_outputs)
+            self.states = [TaylorSeerState() for _ in range(self.num_outputs)]
+            for i, feat in enumerate(attention_outputs):
+                self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
+            return attention_outputs
+
+        should_predict = self.states[0].predict_counter > 0
 
         if not should_predict:
             attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
-            if len(attention_outputs) == 2:
-                attn_output, context_attn_output = attention_outputs
-                self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold)
-                self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold)
-            elif len(attention_outputs) == 3:
-                attn_output, context_attn_output, ip_attn_output = attention_outputs
-                self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold)
-                self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold)
-                self.ip_state.update(ip_attn_output, current_step, self.max_order, self.fresh_threshold)
-        else:
-            attn_output = self.img_state.predict(current_step, self.fresh_threshold)
-            context_attn_output = self.txt_state.predict(current_step, self.fresh_threshold)
-            ip_attn_output = self.ip_state.predict(current_step, self.fresh_threshold)
-            attention_outputs = (attn_output, context_attn_output, ip_attn_output)
+            for i, feat in enumerate(attention_outputs):
+                self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
             return attention_outputs
+        else:
+            predicted_outputs = [state.predict(current_step) for state in self.states]
+            return tuple(predicted_outputs)
 
     def reset_state(self, module: torch.nn.Module) -> None:
-        self.img_state.reset()
-        self.txt_state.reset()
-        self.ip_state.reset()
+        if self.states is not None:
+            for state in self.states:
+                state.reset()
         return module
 
 
 def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
     for name, submodule in module.named_modules():
         if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
-            # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
-            # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
-            # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
             continue
+        print(f"Applying TaylorSeer cache to {name}")
         _apply_taylorseer_cache_on_attention_class(name, submodule, config)
 
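Note: the rewritten hook no longer hard-codes separate image/text/IP-adapter states; it lazily creates one TaylorSeerState per element of whatever tuple the wrapped attention forward returns, which is what the commit message refers to. A rough standalone sketch of that pattern, using toy stand-ins rather than the actual diffusers classes:

import torch

class ToyState:
    """Stand-in for TaylorSeerState: just remembers the last tensor it saw."""
    def __init__(self):
        self.last = None
    def update(self, feat):
        self.last = feat
    def predict(self):
        return self.last

def toy_attention(n_outputs: int):
    # Stand-in for the wrapped attention forward; may return 1, 2 or 3 tensors.
    return tuple(torch.randn(2, 4) for _ in range(n_outputs))

states = None
for step in range(2):
    outputs = toy_attention(n_outputs=3)
    if states is None:
        # Lazily size the per-output state list from the first tuple seen.
        states = [ToyState() for _ in range(len(outputs))]
    for state, feat in zip(states, outputs):
        state.update(feat)

predicted = tuple(state.predict() for state in states)
print(len(predicted))  # matches whatever tuple size the attention module returned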
