
Commit efb7a29

Fix many type hint errors (huggingface#12289)
* fix hidream type hint
* fix hunyuan-video type hint
* fix many type hint
* fix many type hint errors
* fix many type hint errors
* fix many type hint errors
* make style & make quality
1 parent d06750a commit efb7a29

Showing 11 changed files with 30 additions and 30 deletions.
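The fixes follow one convention: block-level `forward` methods that hand back a `(hidden_states, encoder_hidden_states)` pair are annotated `Tuple[torch.Tensor, torch.Tensor]`, while model-level `forward` methods whose output form depends on `return_dict` are annotated `Union[Tuple[torch.Tensor], Transformer2DModelOutput]`. Below is a minimal sketch of that convention, using made-up `ToyBlock`/`ToyTransformer` classes (not the actual diffusers modules) and assuming `Transformer2DModelOutput` is importable from `diffusers.models.modeling_outputs`:

```python
from typing import Tuple, Union

import torch
import torch.nn as nn

# Assumed import path; the real modules in this commit use the same output class.
from diffusers.models.modeling_outputs import Transformer2DModelOutput


class ToyBlock(nn.Module):
    """Hypothetical joint block: processes image and text streams together."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Both streams are returned, so the annotation is a 2-tuple,
        # not a single torch.Tensor.
        return hidden_states, encoder_hidden_states


class ToyTransformer(nn.Module):
    """Hypothetical top-level model: output type depends on `return_dict`."""

    def __init__(self) -> None:
        super().__init__()
        self.block = ToyBlock()

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        hidden_states, _ = self.block(hidden_states, encoder_hidden_states)
        if not return_dict:
            # The tuple branch is why the annotation is Tuple[torch.Tensor]
            # rather than a bare torch.Tensor.
            return (hidden_states,)
        return Transformer2DModelOutput(sample=hidden_states)
```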

src/diffusers/models/attention.py

Lines changed: 1 addition & 1 deletion
@@ -674,7 +674,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 5 additions & 5 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -92,7 +92,7 @@ def pe_selection_index_based_on_dim(self, h, w):
 
         return selected_indices
 
-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         batch_size, num_channels, height, width = latent.size()
         latent = latent.view(
             batch_size,
@@ -173,7 +173,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}
 
@@ -242,7 +242,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         residual_context = encoder_hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -472,7 +472,7 @@ def forward(
         timestep: torch.LongTensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 2 additions & 2 deletions
@@ -122,7 +122,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
         attention_kwargs = attention_kwargs or {}
 
@@ -441,7 +441,7 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

src/diffusers/models/transformers/consisid_transformer_3d.py

Lines changed: 2 additions & 2 deletions
@@ -315,7 +315,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
@@ -691,7 +691,7 @@ def forward(
         id_cond: Optional[torch.Tensor] = None,
         id_vit_hidden: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

src/diffusers/models/transformers/lumina_nextdit2d.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -124,7 +124,7 @@ def forward(
         encoder_mask: torch.Tensor,
         temb: torch.Tensor,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Perform a forward pass through the LuminaNextDiTBlock.
 
@@ -297,7 +297,7 @@ def forward(
         image_rotary_emb: torch.Tensor,
         cross_attention_kwargs: Dict[str, Any] = None,
         return_dict=True,
-    ) -> torch.Tensor:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         Forward pass of LuminaNextDiT.
 

src/diffusers/models/transformers/transformer_bria.py

Lines changed: 2 additions & 2 deletions
@@ -472,7 +472,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_len = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
@@ -588,7 +588,7 @@ def forward(
         return_dict: bool = True,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         The [`BriaTransformer2DModel`] forward method.
 

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from typing import Dict, Union
+from typing import Dict, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -79,7 +79,7 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
@@ -293,7 +293,7 @@ def forward(
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         """
         The [`CogView3PlusTransformer2DModel`] forward method.
 

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 2 additions & 2 deletions
@@ -494,7 +494,7 @@ def forward(
         ] = None,
         attention_mask: Optional[Dict[str, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Timestep conditioning
         (
             norm_hidden_states,
@@ -717,7 +717,7 @@ def forward(
         image_rotary_emb: Optional[
             Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
-    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 5 additions & 5 deletions
@@ -55,7 +55,7 @@ def __init__(self, hidden_size, frequency_embedding_size=256):
         self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
 
-    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
+    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
         t_emb = self.time_proj(timesteps).to(dtype=wdtype)
         t_emb = self.timestep_embedder(t_emb)
         return t_emb
@@ -87,7 +87,7 @@ def __init__(
         self.out_channels = out_channels
         self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
 
-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         latent = self.proj(latent)
         return latent
 
@@ -534,7 +534,7 @@ def forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         wtype = hidden_states.dtype
         (
             shift_msa_i,
@@ -592,7 +592,7 @@ def forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         return self.block(
             hidden_states=hidden_states,
             hidden_states_masks=hidden_states_masks,
@@ -786,7 +786,7 @@ def forward(
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
 
         if encoder_hidden_states is not None:
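The one annotation in this file that departs from the common pattern is `Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]` on the wrapper block's `forward`: as the hunk shows, it only forwards to `self.block(...)`, which may return either a single tensor or a pair, so its return type is the union of both. A rough sketch with hypothetical `SingleStreamBlock`/`DualStreamBlock`/`WrapperBlock` classes (not the actual HiDream modules):

```python
from typing import Tuple, Union

import torch
import torch.nn as nn


class SingleStreamBlock(nn.Module):
    """Hypothetical block: text tokens are assumed to be concatenated into
    the image stream, so only one tensor comes back."""

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> torch.Tensor:
        return hidden_states


class DualStreamBlock(nn.Module):
    """Hypothetical block: image and text streams stay separate, so a pair
    comes back."""

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states


class WrapperBlock(nn.Module):
    """Delegates to whichever block it wraps, so its return annotation is
    the union of the two possible block outputs."""

    def __init__(self, block: nn.Module) -> None:
        super().__init__()
        self.block = block

    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Pure pass-through: the concrete return type depends on which
        # block was wrapped at construction time.
        return self.block(hidden_states, encoder_hidden_states)
```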

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 3 additions & 3 deletions
@@ -529,7 +529,7 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         *args,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
 
@@ -684,7 +684,7 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
 
@@ -1038,7 +1038,7 @@ def forward(
         guidance: torch.Tensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
