@@ -236,7 +236,7 @@ def uncond_skip_callback(module: nn.Module) -> bool:
         is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance
         if not is_using_classifier_free_guidance:
             return False
-
+
         # We skip the unconditional branch only if the following conditions are met:
         # 1. We have completed at least one iteration of the denoiser
         # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
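For orientation while reading this hunk, here is a minimal, self-contained sketch of the skip decision `uncond_skip_callback` encodes. The names `_SketchState`, `timestep_range`, and `skip_every_n` are illustrative assumptions for the sketch, not this PR's exact attributes:

```python
from dataclasses import dataclass


@dataclass
class _SketchState:
    iteration: int          # denoiser iterations completed so far
    current_timestep: int   # timestep of the current denoising step


def should_skip_uncond(
    state: _SketchState,
    timestep_range: tuple,   # (low, high), the user-chosen optimal range
    skip_every_n: int = 2,   # recompute the unconditional branch every N steps
    using_cfg: bool = True,
) -> bool:
    # Without classifier-free guidance there is no unconditional branch to skip.
    if not using_cfg:
        return False
    # 1. At least one full iteration must have run, so a cached result exists.
    if state.iteration <= 0:
        return False
    # 2. Only skip inside the user-specified timestep range.
    low, high = timestep_range
    if not (low <= state.current_timestep <= high):
        return False
    # Otherwise skip on all but every N-th iteration.
    return state.iteration % skip_every_n != 0


print(should_skip_uncond(_SketchState(iteration=3, current_timestep=500), (100, 800)))  # True
```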
@@ -326,7 +326,7 @@ def skip_callback(module: nn.Module) -> bool:
 
 class FasterCacheModelHook(ModelHook):
     _is_stateful = True
-
+
     def __init__(self, uncond_cond_input_kwargs_identifiers: List[str], tensor_format: str) -> None:
         super().__init__()
 
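The model hook is constructed with the names of the kwargs that carry batchwise-concatenated unconditional/conditional tensors. A hedged sketch of how such identifiers could be used to split only those inputs; the helper `split_uncond_cond_kwargs` is an assumption for illustration, not code from this PR:

```python
from typing import Any, Dict, List

import torch


def split_uncond_cond_kwargs(kwargs: Dict[str, Any], identifiers: List[str]):
    """Chunk only the kwargs named in `identifiers` into (uncond, cond) halves."""
    uncond_kwargs, cond_kwargs = dict(kwargs), dict(kwargs)
    for name in identifiers:
        value = kwargs.get(name)
        if torch.is_tensor(value):
            uncond_kwargs[name], cond_kwargs[name] = value.chunk(2, dim=0)
    return uncond_kwargs, cond_kwargs


inputs = {"hidden_states": torch.randn(4, 8), "timestep": torch.tensor([10])}
uncond, cond = split_uncond_cond_kwargs(inputs, identifiers=["hidden_states"])
print(uncond["hidden_states"].shape)  # torch.Size([2, 8])
```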
@@ -397,7 +397,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
         else:
             # TODO(aryan): remove later
             logger.debug("Computing unconditional branch")
-
+
             uncond_states, cond_states = hidden_states.chunk(2, dim=0)
             if self.tensor_format == "BCFHW":
                 uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
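A small runnable sketch of the split-and-permute step in this hunk, with illustrative shapes. Under CFG the denoiser input stacks [uncond; cond] along dim 0, and for `"BCFHW"` layouts the permute moves the frame axis next to batch:

```python
import torch

batch_size, channels, frames, height, width = 2, 4, 8, 16, 16
# Under CFG the denoiser sees [uncond; cond] concatenated along the batch dim.
hidden_states = torch.randn(2 * batch_size, channels, frames, height, width)

uncond_states, cond_states = hidden_states.chunk(2, dim=0)
# For "BCFHW" inputs: (B, C, F, H, W) -> (B, F, C, H, W),
# matching the permute(0, 2, 1, 3, 4) in the hook.
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
cond_states = cond_states.permute(0, 2, 1, 3, 4)
print(uncond_states.shape)  # torch.Size([2, 8, 4, 16, 16])
```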
@@ -412,16 +412,16 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
             state.high_frequency_delta = high_freq_uncond - high_freq_cond
 
         state.iteration += 1
-        output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states
+        output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states
         return output
-
+
     def reset_state(self, module: nn.Module) -> None:
         module._fastercache_state.reset()
 
 
 class FasterCacheBlockHook(ModelHook):
     _is_stateful = True
-
+
     def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
         args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
         state: FasterCacheState = module._fastercache_state
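`high_frequency_delta` is the difference between the high-frequency components of the unconditional and conditional branch outputs. A hedged sketch of one way such a low/high split can be computed with `torch.fft`; the `cutoff_ratio` and the masking scheme are assumptions for this sketch, not necessarily this PR's exact implementation:

```python
import torch


def split_low_high_freq(x: torch.Tensor, cutoff_ratio: float = 0.25):
    """Split the last two (spatial) dims of x into low/high frequency parts.

    `cutoff_ratio` is an assumed hyperparameter, not a value from this PR.
    """
    freq = torch.fft.fftshift(torch.fft.fft2(x.float()), dim=(-2, -1))
    h, w = x.shape[-2:]
    cy, cx = h // 2, w // 2
    ry, rx = int(h * cutoff_ratio), int(w * cutoff_ratio)
    mask = torch.zeros_like(freq, dtype=torch.bool)
    mask[..., cy - ry : cy + ry, cx - rx : cx + rx] = True
    return freq * mask, freq * ~mask  # (low-frequency, high-frequency)


uncond_output = torch.randn(1, 4, 32, 32)
cond_output = torch.randn(1, 4, 32, 32)
_, high_freq_uncond = split_low_high_freq(uncond_output)
_, high_freq_cond = split_low_high_freq(cond_output)
high_frequency_delta = high_freq_uncond - high_freq_cond  # what the state stores
```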
@@ -443,7 +443,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
             # TODO(aryan): remove later
             logger.debug("Skipping layer computation")
             t_2_output, t_output = state.cache
-
+
             # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed
             if t_2_output.size(0) != batch_size:
                 # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
@@ -455,7 +455,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
                 # take the conditional branch outputs.
                 assert t_output.size(0) == 2 * batch_size
                 t_output = t_output[batch_size:]
-
+
             output = t_output + (t_output - t_2_output) * state.weight_callback(module)
         else:
             output = module._old_forward(*args, **kwargs)
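The skip path approximates the block output by linearly extrapolating the two most recent cached outputs. A self-contained sketch of that arithmetic; `weight = 0.5` stands in for whatever `state.weight_callback(module)` returns:

```python
import torch

batch_size = 2
t_2_output = torch.randn(batch_size, 16)  # cached output from two steps ago
t_output = torch.randn(batch_size, 16)    # cached output from the previous step
weight = 0.5                              # stand-in for state.weight_callback(module)

# If a cached entry still holds both CFG branches (2 * batch_size rows),
# keep only the conditional half, mirroring t_output[batch_size:] above.
full_cache = torch.randn(2 * batch_size, 16)
cond_only = full_cache[batch_size:]

# First-order extrapolation: continue the trend between the two cached outputs.
output = t_output + (t_output - t_2_output) * weight
```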
@@ -465,7 +465,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
             cache_output = output
             if output.size(0) == state.batch_size:
                 cache_output = cache_output.chunk(2, dim=0)[1]
-
+
             # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference
             # and only-conditional branch inference.
             assert 2 * cache_output.size(0) == state.batch_size
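When the block does execute, only the conditional half of a full-batch output is cached, and the two most recent entries are kept. A minimal sketch of such a rolling two-entry cache, with illustrative names (`_SketchBlockState` and `push` are not this PR's API):

```python
import torch


class _SketchBlockState:
    def __init__(self, batch_size: int):
        self.batch_size = batch_size  # full CFG batch size (uncond + cond)
        self.cache = None             # (t_2_output, t_output)

    def push(self, output: torch.Tensor) -> None:
        cache_output = output
        # A full-batch output holds both CFG branches; keep the conditional half.
        if output.size(0) == self.batch_size:
            cache_output = cache_output.chunk(2, dim=0)[1]
        assert 2 * cache_output.size(0) == self.batch_size
        if self.cache is None:
            self.cache = (cache_output, cache_output)
        else:
            self.cache = (self.cache[1], cache_output)  # previous t becomes t-2


state = _SketchBlockState(batch_size=4)
state.push(torch.randn(4, 16))  # both branches: conditional half is cached
state.push(torch.randn(2, 16))  # conditional-only output is cached as-is
```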
@@ -477,7 +477,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
 
         state.iteration += 1
         return module._diffusers_hook.post_forward(module, output)
-
+
     def reset_state(self, module: nn.Module) -> None:
         module._fastercache_state.reset()
 
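Both hooks set `_is_stateful = True` and clear their per-module state via `reset_state`, so cached outputs and counters from one generation never leak into the next. A hedged sketch of the state object's reset pattern; the field names mirror the ones used above, but the class itself is illustrative:

```python
class _SketchFasterCacheState:
    def __init__(self):
        self.reset()

    def reset(self):
        # Clear everything accumulated during one generation so that the next
        # prompt starts from a clean slate.
        self.iteration = 0
        self.cache = None                 # (t_2_output, t_output) once populated
        self.high_frequency_delta = None  # set by the model hook under CFG
```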