Skip to content

Commit 82d85bd

Browse files
committed
make style
1 parent 6047114 commit 82d85bd

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

src/diffusers/models/hooks.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
7777
The module detached from this hook.
7878
"""
7979
return module
80-
80+
8181
def reset_state(self):
8282
if self._is_stateful:
8383
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
@@ -108,7 +108,7 @@ def detach_hook(self, module):
108108
for hook in self.hooks:
109109
module = hook.detach_hook(module)
110110
return module
111-
111+
112112
def reset_state(self):
113113
for hook in self.hooks:
114114
if hook._is_stateful:
@@ -216,7 +216,9 @@ def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
216216
module (`torch.nn.Module`):
217217
The module to reset the stateful hooks from.
218218
"""
219-
if hasattr(module, "_diffusers_hook") and (module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)):
219+
if hasattr(module, "_diffusers_hook") and (
220+
module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)
221+
):
220222
module._diffusers_hook.reset_state(module)
221223

222224
if recurse:

src/diffusers/pipelines/faster_cache_utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def uncond_skip_callback(module: nn.Module) -> bool:
236236
is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance
237237
if not is_using_classifier_free_guidance:
238238
return False
239-
239+
240240
# We skip the unconditional branch only if the following conditions are met:
241241
# 1. We have completed at least one iteration of the denoiser
242242
# 2. The current timestep is within the range specified by the user. This is the optimal timestep range
@@ -326,7 +326,7 @@ def skip_callback(module: nn.Module) -> bool:
326326

327327
class FasterCacheModelHook(ModelHook):
328328
_is_stateful = True
329-
329+
330330
def __init__(self, uncond_cond_input_kwargs_identifiers: List[str], tensor_format: str) -> None:
331331
super().__init__()
332332

@@ -397,7 +397,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
397397
else:
398398
# TODO(aryan): remove later
399399
logger.debug("Computing unconditional branch")
400-
400+
401401
uncond_states, cond_states = hidden_states.chunk(2, dim=0)
402402
if self.tensor_format == "BCFHW":
403403
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
@@ -412,16 +412,16 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
412412
state.high_frequency_delta = high_freq_uncond - high_freq_cond
413413

414414
state.iteration += 1
415-
output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states
415+
output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states
416416
return output
417-
417+
418418
def reset_state(self, module: nn.Module) -> None:
419419
module._fastercache_state.reset()
420420

421421

422422
class FasterCacheBlockHook(ModelHook):
423423
_is_stateful = True
424-
424+
425425
def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
426426
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
427427
state: FasterCacheState = module._fastercache_state
@@ -443,7 +443,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
443443
# TODO(aryan): remove later
444444
logger.debug("Skipping layer computation")
445445
t_2_output, t_output = state.cache
446-
446+
447447
# TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed
448448
if t_2_output.size(0) != batch_size:
449449
# The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
@@ -455,7 +455,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
455455
# take the conditional branch outputs.
456456
assert t_output.size(0) == 2 * batch_size
457457
t_output = t_output[batch_size:]
458-
458+
459459
output = t_output + (t_output - t_2_output) * state.weight_callback(module)
460460
else:
461461
output = module._old_forward(*args, **kwargs)
@@ -465,7 +465,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
465465
cache_output = output
466466
if output.size(0) == state.batch_size:
467467
cache_output = cache_output.chunk(2, dim=0)[1]
468-
468+
469469
# Just to be safe that the output is of the correct size for both unconditional-conditional branch inference
470470
# and only-conditional branch inference.
471471
assert 2 * cache_output.size(0) == state.batch_size
@@ -477,7 +477,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
477477

478478
state.iteration += 1
479479
return module._diffusers_hook.post_forward(module, output)
480-
480+
481481
def reset_state(self, module: nn.Module) -> None:
482482
module._fastercache_state.reset()
483483

0 commit comments

Comments (0)