Skip to content

Commit 0602044

Browse files
author
toilaluan
committed
still update in warmup steps
1 parent 8f80072 commit 0602044

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

src/diffusers/hooks/taylorseer_cache.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
8585

8686
if self.states is None:
8787
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
88-
if self.warmup_steps_counter < self.warmup_steps:
89-
logger.debug(f"warmup_steps_counter: {self.warmup_steps_counter}")
90-
self.warmup_steps_counter += 1
91-
return attention_outputs
9288
if isinstance(attention_outputs, torch.Tensor):
9389
attention_outputs = [attention_outputs]
9490
self.num_outputs = len(attention_outputs)
@@ -97,7 +93,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
9793
self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
9894
return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
9995

100-
should_predict = self.states[0].predict_counter > 0
96+
should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps
10197

10298
if not should_predict:
10399
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
@@ -108,9 +104,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
108104
return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
109105
else:
110106
predicted_outputs = [state.predict(current_step) for state in self.states]
111-
if len(predicted_outputs) == 1:
112-
return predicted_outputs[0]
113-
return tuple(predicted_outputs)
107+
return predicted_outputs[0] if len(predicted_outputs) == 1 else predicted_outputs
114108

115109
def reset_state(self, module: torch.nn.Module) -> None:
116110
if self.states is not None:

0 commit comments

Comments
 (0)