@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -92,7 +92,7 @@ def pe_selection_index_based_on_dim(self, h, w):

         return selected_indices

-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         batch_size, num_channels, height, width = latent.size()
         latent = latent.view(
             batch_size,
@@ -173,7 +173,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}

@@ -242,7 +242,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         residual_context = encoder_hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -472,7 +472,7 @@ def forward(
         timestep: torch.LongTensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
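For reference, the widened return annotation on the model-level forward follows the usual return_dict convention in diffusers: a Transformer2DModelOutput by default, or a one-element tuple of the sample tensor when return_dict=False. Below is a minimal, self-contained sketch of that convention; the local Transformer2DModelOutput dataclass and the forward_sketch function are illustrative stand-ins, not code from this PR.

from dataclasses import dataclass
from typing import Tuple, Union

import torch


@dataclass
class Transformer2DModelOutput:
    # Stand-in for diffusers' Transformer2DModelOutput, kept local so the sketch runs on its own.
    sample: torch.Tensor


def forward_sketch(
    sample: torch.Tensor, return_dict: bool = True
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
    # Mirrors the convention the new annotation documents: a one-element tuple when
    # return_dict=False, otherwise the output dataclass.
    if not return_dict:
        return (sample,)
    return Transformer2DModelOutput(sample=sample)


out = forward_sketch(torch.randn(1, 4, 64, 64))
sample = out.sample  # Transformer2DModelOutput path (return_dict=True)

out_tuple = forward_sketch(torch.randn(1, 4, 64, 64), return_dict=False)
sample = out_tuple[0]  # tuple path (return_dict=False)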