@@ -115,18 +115,9 @@ def reset_state(self, module):
 
 
 class PyramidAttentionBroadcastHook(ModelHook):
-    def __init__(
-        self,
-        skip_callback: Callable[[torch.nn.Module], bool],
-        # skip_range: int,
-        # timestep_range: Tuple[int, int],
-        # timestep_callback: Callable[[], Union[torch.LongTensor, int]],
-    ) -> None:
+    def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None:
         super().__init__()
 
-        # self.skip_range = skip_range
-        # self.timestep_range = timestep_range
-        # self.timestep_callback = timestep_callback
         self.skip_callback = skip_callback
 
         self.cache = None
@@ -135,15 +126,6 @@ def __init__(
135126 def new_forward (self , module : torch .nn .Module , * args , ** kwargs ) -> Any :
136127 args , kwargs = module ._diffusers_hook .pre_forward (module , * args , ** kwargs )
137128
138- # current_timestep = self.timestep_callback()
139- # is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
140- # should_compute_attention = self._iteration % self.skip_range == 0
141-
142- # if not is_within_timestep_range or should_compute_attention:
143- # output = module._old_forward(*args, **kwargs)
144- # else:
145- # output = self.attention_cache
146-
147129 if self .cache is not None and self .skip_callback (module ):
148130 output = self .cache
149131 else :
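
A minimal sketch (not part of this commit) of how the dropped skip_range / timestep_range / timestep_callback arguments could still be expressed through the new skip_callback parameter, by building a closure and passing it to the simplified constructor. The helper name make_skip_callback and its signature are hypothetical illustrations, not part of the diffusers API.

from typing import Callable, Tuple

import torch


def make_skip_callback(
    skip_range: int,
    timestep_range: Tuple[int, int],
    timestep_callback: Callable[[], int],
) -> Callable[[torch.nn.Module], bool]:
    # Hypothetical helper: re-creates the commented-out logic removed in this
    # commit as a closure usable as PyramidAttentionBroadcastHook(skip_callback=...).
    iteration = 0

    def skip_callback(module: torch.nn.Module) -> bool:
        nonlocal iteration
        current_timestep = timestep_callback()
        is_within_timestep_range = timestep_range[0] < current_timestep < timestep_range[1]
        should_compute_attention = iteration % skip_range == 0
        iteration += 1
        # Reuse the cached attention output only when inside the timestep window
        # and this is not a step on which attention must be recomputed.
        return is_within_timestep_range and not should_compute_attention

    return skip_callback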