Skip to content

Commit fa9a1f3

Browse files
committed
make style
1 parent 251ade1 commit fa9a1f3

File tree

3 files changed

+3
-10
lines changed

3 files changed

+3
-10
lines changed

src/diffusers/hooks/faster_cache.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
3434
"^blocks.*attn",
3535
"^transformer_blocks.*attn",
36-
"^single_transformer_blocks.*attn"
36+
"^single_transformer_blocks.*attn",
3737
)
3838
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
3939
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
@@ -483,10 +483,7 @@ def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
483483
return module
484484

485485

486-
def apply_faster_cache(
487-
module: torch.nn.Module,
488-
config: FasterCacheConfig
489-
) -> None:
486+
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
490487
r"""
491488
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
492489

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,7 @@ def reset_state(self, module: torch.nn.Module) -> None:
175175
return module
176176

177177

178-
def apply_pyramid_attention_broadcast(
179-
module: torch.nn.Module,
180-
config: PyramidAttentionBroadcastConfig
181-
):
178+
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
182179
r"""
183180
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
184181

tests/pipelines/test_pipelines_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2609,7 +2609,6 @@ def faster_cache_state_check_callback(pipe, i, t, kwargs):
26092609
if not hasattr(module, "_diffusers_hook"):
26102610
continue
26112611

2612-
26132612
if name == "":
26142613
# Root denoiser module
26152614
state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state

0 commit comments

Comments
 (0)