|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import functools |
16 | | -from typing import Any, Callable, Dict, Tuple, Union |
| 16 | +from typing import Any, Callable, Dict, Tuple |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 |
|
@@ -117,45 +117,72 @@ def reset_state(self, module): |
117 | 117 | class PyramidAttentionBroadcastHook(ModelHook): |
118 | 118 | def __init__( |
119 | 119 | self, |
120 | | - skip_range: int, |
121 | | - timestep_range: Tuple[int, int], |
122 | | - timestep_callback: Callable[[], Union[torch.LongTensor, int]], |
| 120 | + skip_callback: Callable[[torch.nn.Module], bool], |
| 121 | + # skip_range: int, |
| 122 | + # timestep_range: Tuple[int, int], |
| 123 | + # timestep_callback: Callable[[], Union[torch.LongTensor, int]], |
123 | 124 | ) -> None: |
124 | 125 | super().__init__() |
125 | 126 |
|
126 | | - self.skip_range = skip_range |
127 | | - self.timestep_range = timestep_range |
128 | | - self.timestep_callback = timestep_callback |
| 127 | + # self.skip_range = skip_range |
| 128 | + # self.timestep_range = timestep_range |
| 129 | + # self.timestep_callback = timestep_callback |
| 130 | + self.skip_callback = skip_callback |
129 | 131 |
|
130 | | - self.attention_cache = None |
| 132 | + self.cache = None |
131 | 133 | self._iteration = 0 |
132 | 134 |
|
133 | 135 | def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: |
134 | 136 | args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) |
135 | 137 |
|
136 | | - current_timestep = self.timestep_callback() |
137 | | - is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] |
138 | | - should_compute_attention = self._iteration % self.skip_range == 0 |
| 138 | + # current_timestep = self.timestep_callback() |
| 139 | + # is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] |
| 140 | + # should_compute_attention = self._iteration % self.skip_range == 0 |
139 | 141 |
|
140 | | - if not is_within_timestep_range or should_compute_attention: |
141 | | - output = module._old_forward(*args, **kwargs) |
142 | | - else: |
143 | | - output = self.attention_cache |
| 142 | + # if not is_within_timestep_range or should_compute_attention: |
| 143 | + # output = module._old_forward(*args, **kwargs) |
| 144 | + # else: |
| 145 | + # output = self.attention_cache |
144 | 146 |
|
145 | | - self._iteration = self._iteration + 1 |
| 147 | + if self.cache is not None and self.skip_callback(module): |
| 148 | + output = self.cache |
| 149 | + else: |
| 150 | + output = module._old_forward(*args, **kwargs) |
146 | 151 |
|
147 | 152 | return module._diffusers_hook.post_forward(module, output) |
148 | 153 |
|
149 | 154 | def post_forward(self, module: torch.nn.Module, output: Any) -> Any: |
150 | | - self.attention_cache = output |
| 155 | + self.cache = output |
151 | 156 | return output |
152 | 157 |
|
153 | 158 | def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: |
154 | | - self.attention_cache = None |
| 159 | + self.cache = None |
155 | 160 | self._iteration = 0 |
156 | 161 | return module |
157 | 162 |
|
158 | 163 |
|
| 164 | +class LayerSkipHook(ModelHook): |
| 165 | +    def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None:
| 166 | + super().__init__() |
| 167 | + |
| 168 | +        self.skip_callback = skip_callback
| 169 | + |
| 170 | + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: |
| 171 | + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) |
| 172 | + |
| 173 | + if self.skip_callback(module): |
| 174 | +            # To skip this layer, we would have to return its input unchanged so that it can
| 175 | +            # be passed on to the next layer. But at this point, we don't know which arguments
| 176 | +            # the next layer expects. Even if we did, argument order matters unless everything
| 177 | +            # is passed as kwargs, which is usually not the case with hidden_states, encoder_hidden_states,
| 178 | +            # temb, etc. TODO(aryan): implement correctly later
| 179 | + output = None |
| 180 | + else: |
| 181 | + output = module._old_forward(*args, **kwargs) |
| 182 | + |
| 183 | + return module._diffusers_hook.post_forward(module, output) |
| 184 | + |
| 185 | + |
159 | 186 | def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): |
160 | 187 | r""" |
161 | 188 | Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove |
|
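For context, here is a minimal usage sketch of the callback-driven caching introduced above. It assumes `PyramidAttentionBroadcastHook` and `add_hook_to_module` from this diff are in scope (their import path is not shown in this hunk); the counter-based skip policy and the `ToyAttention` module are illustrative stand-ins, not part of the commit.

import torch

# Hypothetical stand-in for an attention block whose output we want to reuse
# across denoising steps; not part of the diffusers codebase.
class ToyAttention(torch.nn.Module):
    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)


def make_counter_skip_callback(skip_range: int):
    # Recreates the old skip_range/_iteration behaviour as an external policy:
    # recompute on every skip_range-th consultation, reuse the cache otherwise.
    call_count = 0

    def skip_callback(module: torch.nn.Module) -> bool:
        nonlocal call_count
        should_skip = call_count % skip_range != 0
        call_count += 1
        return should_skip

    return skip_callback


attn = ToyAttention()
hook = PyramidAttentionBroadcastHook(skip_callback=make_counter_skip_callback(skip_range=2))
add_hook_to_module(attn, hook)

sample = torch.randn(1, 16, 8)
out_0 = attn(sample)  # cache is empty, so the callback is never consulted and attention is computed
out_1 = attn(sample)  # callback sees count 0 -> recompute and refresh the cache
out_2 = attn(sample)  # callback sees count 1 -> the cached output is returned unchanged

Because `new_forward` only consults the callback once a cached output exists, the policy above is first asked on the second call; a more realistic policy would key off the current denoising timestep, which is roughly what the commented-out `timestep_range`/`timestep_callback` logic did before this change.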