Commit d929ab2

Author: toilaluan
Commit message: apply ruff
Parent: 9290b58

File tree: 1 file changed (+8, -9 lines)


src/diffusers/hooks/taylorseer_cache.py

Lines changed: 8 additions & 9 deletions
@@ -1,6 +1,6 @@
 import torch
 from dataclasses import dataclass
-from typing import Callable, Optional, List, Dict, Tuple
+from typing import Optional, List, Dict, Tuple
 from .hooks import ModelHook
 import math
 from ..models.attention import Attention
@@ -11,11 +11,9 @@
 from ..hooks import HookRegistry
 from ..utils import logging
 import re
-from collections import defaultdict


-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
+logger = logging.get_logger(__name__)

 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

@@ -70,6 +68,7 @@ def __repr__(self) -> str:
     def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]:
         return _CACHE_TEMPLATES

+
 class TaylorSeerOutputState:
     """
     Manages the state for Taylor series-based prediction of a single attention output.
@@ -219,9 +218,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             module_dtype = attention_outputs[0].dtype
             self.num_outputs = len(attention_outputs)
             self.states = [
-                TaylorSeerOutputState(
-                    self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip
-                )
+                TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip)
                 for _ in range(self.num_outputs)
             ]
             for i, features in enumerate(attention_outputs):
@@ -249,7 +246,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             attention_outputs = list(attention_outputs)
             is_first_update = self.step_counter == 0  # Only True for the very first step
             for i, features in enumerate(attention_outputs):
-                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
+                self.states[i].update(
+                    features, self.step_counter, self.max_order, self.predict_steps, is_first_update
+                )
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
         else:
             # Predict using Taylor series
@@ -330,4 +329,4 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee
         is_skip=is_skip,
     )

-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
+    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
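Note for readers of this diff: the changes above are style-only (ruff formatting plus removal of unused imports), but the touched class, TaylorSeerOutputState, is documented as managing "Taylor series-based prediction of a single attention output", and new_forward either refreshes those states (the update branch shown above) or predicts outputs from them. The snippet below is a minimal, self-contained sketch of that idea only, assuming finite-difference "Taylor factors" cached at full compute steps and extrapolated at skipped steps; the names ToyTaylorState, update, predict, and max_order here are illustrative assumptions, not the diffusers API.

# Minimal sketch (an assumption, not the diffusers implementation) of Taylor
# series-based prediction of a cached feature tensor across diffusion steps.
import math

import torch


class ToyTaylorState:
    def __init__(self, max_order: int = 1):
        self.max_order = max_order
        self.taylor_factors = {}  # order -> finite-difference derivative estimate
        self.last_step = None

    def update(self, features: torch.Tensor, step: int) -> None:
        # On a full compute step, rebuild the factors from the previous cache.
        new_factors = {0: features}
        if self.last_step is not None:
            dt = step - self.last_step
            for order in range(1, self.max_order + 1):
                prev = self.taylor_factors.get(order - 1)
                if prev is None:
                    break
                new_factors[order] = (new_factors[order - 1] - prev) / dt
        self.taylor_factors = new_factors
        self.last_step = step

    def predict(self, step: int) -> torch.Tensor:
        # On a skipped step, extrapolate with a Taylor expansion around the
        # last cached step instead of recomputing attention.
        dt = step - self.last_step
        out = torch.zeros_like(self.taylor_factors[0])
        for order, factor in self.taylor_factors.items():
            out = out + factor * (dt ** order) / math.factorial(order)
        return out


# Example: cache at steps 0 and 2, then predict step 3 by linear extrapolation.
state = ToyTaylorState(max_order=1)
state.update(torch.zeros(2, 4), step=0)
state.update(torch.ones(2, 4), step=2)
approx = state.predict(step=3)  # ~1.5 everywhere for this toy input

In the actual hook, self.states holds one such state per attention output, and update takes additional arguments (the step counter, max_order, predict_steps, and is_first_update), as visible in the hunks above.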
