Skip to content

Commit 8f80072

Browse files
author
toilaluan
committed
use logger for printing, add warmup feature
1 parent 8f495b6 commit 8f80072

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

src/diffusers/hooks/taylorseer_cache.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
_ATTENTION_CLASSES,
1313
)
1414
from ..hooks import HookRegistry
15-
15+
from ..utils import logging
16+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
1617
_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
1718

1819
@dataclass
1920
class TaylorSeerCacheConfig:
21+
warmup_steps: int = 3 # full compute some first steps
2022
fresh_threshold: int = 5 # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed
2123
max_order: int = 1 # order of Taylor series expansion
2224
current_timestep_callback: Callable[[], int] = None
@@ -33,7 +35,9 @@ def reset(self):
3335
self.taylor_factors = {}
3436

3537
def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int):
36-
N = math.abs(current_step - self.last_step)
38+
logger.debug("="*10)
39+
N = self.last_step - current_step
40+
logger.debug(f"update: N: {N}, current_step: {current_step}, last_step: {self.last_step}")
3741
# initialize the first order taylor factors
3842
new_taylor_factors = {0: features}
3943
for i in range(max_order):
@@ -44,6 +48,9 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, refr
4448
self.taylor_factors = new_taylor_factors
4549
self.last_step = current_step
4650
self.predict_counter = refresh_threshold
51+
logger.debug(f"last_step: {self.last_step}")
52+
logger.debug(f"predict_counter: {self.predict_counter}")
53+
logger.debug("="*10)
4754

4855
def predict(self, current_step: int):
4956
k = current_step - self.last_step
@@ -52,20 +59,24 @@ def predict(self, current_step: int):
5259
for i in range(len(self.taylor_factors)):
5360
output += self.taylor_factors[i] * (k ** i) / math.factorial(i)
5461
self.predict_counter -= 1
62+
logger.debug(f"predict_counter: {self.predict_counter}")
63+
logger.debug(f"k: {k}")
5564
return output
5665

5766
class TaylorSeerAttentionCacheHook(ModelHook):
5867
_is_stateful = True
5968

60-
def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int]):
69+
def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int):
6170
super().__init__()
6271
self.fresh_threshold = fresh_threshold
6372
self.max_order = max_order
6473
self.current_timestep_callback = current_timestep_callback
74+
self.warmup_steps = warmup_steps
6575

6676
def initialize_hook(self, module):
6777
self.states = None
6878
self.num_outputs = None
79+
self.warmup_steps_counter = 0
6980
return module
7081

7182
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -74,21 +85,31 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
7485

7586
if self.states is None:
7687
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
88+
if self.warmup_steps_counter < self.warmup_steps:
89+
logger.debug(f"warmup_steps_counter: {self.warmup_steps_counter}")
90+
self.warmup_steps_counter += 1
91+
return attention_outputs
92+
if isinstance(attention_outputs, torch.Tensor):
93+
attention_outputs = [attention_outputs]
7794
self.num_outputs = len(attention_outputs)
7895
self.states = [TaylorSeerState() for _ in range(self.num_outputs)]
7996
for i, feat in enumerate(attention_outputs):
8097
self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
81-
return attention_outputs
98+
return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
8299

83100
should_predict = self.states[0].predict_counter > 0
84101

85102
if not should_predict:
86103
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
104+
if isinstance(attention_outputs, torch.Tensor):
105+
attention_outputs = [attention_outputs]
87106
for i, feat in enumerate(attention_outputs):
88107
self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
89-
return attention_outputs
108+
return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
90109
else:
91110
predicted_outputs = [state.predict(current_step) for state in self.states]
111+
if len(predicted_outputs) == 1:
112+
return predicted_outputs[0]
92113
return tuple(predicted_outputs)
93114

94115
def reset_state(self, module: torch.nn.Module) -> None:
@@ -101,7 +122,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
101122
for name, submodule in module.named_modules():
102123
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
103124
continue
104-
print(f"Applying TaylorSeer cache to {name}")
125+
logger.debug(f"Applying TaylorSeer cache to {name}")
105126
_apply_taylorseer_cache_on_attention_class(name, submodule, config)
106127

107128

@@ -111,5 +132,5 @@ def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, con
111132

112133
def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig):
113134
registry = HookRegistry.check_if_exists_or_initialize(module)
114-
hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback)
135+
hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps)
115136
registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)

0 commit comments

Comments
 (0)