|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -from typing import Optional, Tuple |
| 16 | +from typing import Any, Dict, Optional, Tuple |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | import torch.nn as nn |
20 | 20 |
|
21 | 21 | from ...configuration_utils import ConfigMixin, register_to_config |
22 | | -from ...utils import logging |
| 22 | +from ...utils import is_torch_version, logging |
23 | 23 | from ...utils.torch_utils import maybe_allow_in_graph |
24 | 24 | from ..attention import FeedForward |
25 | 25 | from ..attention_processor import Attention, MochiAttnProcessor2_0 |
@@ -131,7 +131,7 @@ def forward( |
131 | 131 | ) * torch.tanh(enc_gate_msa).unsqueeze(1) |
132 | 132 | norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) |
133 | 133 | context_ff_output = self.ff_context(norm_encoder_hidden_states) |
134 | | - encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(enc_gate_mlp).unsqueeze(0) |
| 134 | + encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(enc_gate_mlp).unsqueeze(1) |
135 | 135 |
|
136 | 136 | return hidden_states, encoder_hidden_states |
137 | 137 |
|
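Note on the `unsqueeze(0)` → `unsqueeze(1)` change above: `enc_gate_mlp` carries one gate vector per sample, so it has to broadcast across the sequence axis of `context_ff_output`, not the batch axis. A minimal sketch of the shape arithmetic, with toy sizes that are purely illustrative:

```python
import torch

batch, seq_len, dim = 2, 77, 8                      # toy sizes, not the real config
context_ff_output = torch.randn(batch, seq_len, dim)
enc_gate_mlp = torch.randn(batch, dim)              # one gate per sample

# unsqueeze(1) -> (batch, 1, dim): broadcasts over the sequence axis, as intended.
gated = context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(1)
assert gated.shape == (batch, seq_len, dim)

# unsqueeze(0) -> (1, batch, dim): lines the batch up with the sequence axis instead,
# which raises a shape error unless batch == seq_len (or batch == 1).
```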
@@ -248,6 +248,12 @@ def __init__( |
248 | 248 | ) |
249 | 249 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) |
250 | 250 |
|
| 251 | + self.gradient_checkpointing = False |
| 252 | + |
| 253 | + def _set_gradient_checkpointing(self, module, value=False): |
| 254 | + if hasattr(module, "gradient_checkpointing"): |
| 255 | + module.gradient_checkpointing = value |
| 256 | + |
251 | 257 | def forward( |
252 | 258 | self, |
253 | 259 | hidden_states: torch.Tensor, |
@@ -280,13 +286,30 @@ def forward( |
280 | 286 | ) |
281 | 287 |
|
282 | 288 | for i, block in enumerate(self.transformer_blocks): |
283 | | - hidden_states, encoder_hidden_states = block( |
284 | | - hidden_states=hidden_states, |
285 | | - encoder_hidden_states=encoder_hidden_states, |
286 | | - temb=temb, |
287 | | - image_rotary_emb=image_rotary_emb, |
288 | | - ) |
289 | | - print(hidden_states.mean(), hidden_states.std()) |
| 289 | + if self.gradient_checkpointing: |
| 290 | + |
| 291 | + def create_custom_forward(module): |
| 292 | + def custom_forward(*inputs): |
| 293 | + return module(*inputs) |
| 294 | + |
| 295 | + return custom_forward |
| 296 | + |
| 297 | + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} |
| 298 | + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( |
| 299 | + create_custom_forward(block), |
| 300 | + hidden_states, |
| 301 | + encoder_hidden_states, |
| 302 | + temb, |
| 303 | + image_rotary_emb, |
| 304 | + **ckpt_kwargs, |
| 305 | + ) |
| 306 | + else: |
| 307 | + hidden_states, encoder_hidden_states = block( |
| 308 | + hidden_states=hidden_states, |
| 309 | + encoder_hidden_states=encoder_hidden_states, |
| 310 | + temb=temb, |
| 311 | + image_rotary_emb=image_rotary_emb, |
| 312 | + ) |
290 | 313 |
|
291 | 314 | hidden_states = self.norm_out(hidden_states, temb) |
292 | 315 | hidden_states = self.proj_out(hidden_states) |
|
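For context on the checkpointing changes: assuming this transformer inherits from diffusers' `ModelMixin` (not visible in these hunks), the `gradient_checkpointing` flag and the `_set_gradient_checkpointing` hook are what `ModelMixin.enable_gradient_checkpointing()` toggles. A hedged usage sketch; the checkpoint id is shown only for illustration:

```python
from diffusers import MochiTransformer3DModel  # assumed export path for this class

# enable_gradient_checkpointing() (from ModelMixin) applies
# _set_gradient_checkpointing(value=True) to every submodule, so the transformer
# blocks above recompute activations during backward instead of caching them,
# lowering peak memory at the cost of extra compute.
model = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer"  # illustrative checkpoint id
)
model.enable_gradient_checkpointing()
```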