
Commit 68d7db3

fix many type hint errors
1 parent 3952475 commit 68d7db3

7 files changed: 15 additions, 15 deletions

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 4 additions & 4 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, Tuple

 import torch
 import torch.nn as nn
@@ -92,7 +92,7 @@ def pe_selection_index_based_on_dim(self, h, w):

         return selected_indices

-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         batch_size, num_channels, height, width = latent.size()
         latent = latent.view(
             batch_size,
@@ -242,7 +242,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         residual_context = encoder_hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -472,7 +472,7 @@ def forward(
         timestep: torch.LongTensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
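Note: the Union[torch.Tensor, Transformer2DModelOutput] annotations added throughout this commit describe diffusers' return_dict convention. A minimal sketch of that convention, using a stand-in dataclass rather than the real diffusers output class:

# A minimal sketch (not the diffusers source): the return_dict convention
# that the new Union[torch.Tensor, Transformer2DModelOutput] annotations encode.
from dataclasses import dataclass
from typing import Union

import torch


@dataclass
class Transformer2DModelOutput:
    # Stand-in for the diffusers output dataclass of the same name.
    sample: torch.Tensor


def forward(hidden_states: torch.Tensor, return_dict: bool = True) -> Union[torch.Tensor, Transformer2DModelOutput]:
    output = hidden_states  # the actual model computation is elided
    if not return_dict:
        return output  # bare tensor when the caller opts out of the dataclass
    return Transformer2DModelOutput(sample=output)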

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 1 addition & 1 deletion

@@ -441,7 +441,7 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

src/diffusers/models/transformers/consisid_transformer_3d.py

Lines changed: 1 addition & 1 deletion

@@ -691,7 +691,7 @@ def forward(
         id_cond: Optional[torch.Tensor] = None,
         id_vit_hidden: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)

src/diffusers/models/transformers/lumina_nextdit2d.py

Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.nn as nn
@@ -297,7 +297,7 @@ def forward(
         image_rotary_emb: torch.Tensor,
         cross_attention_kwargs: Dict[str, Any] = None,
         return_dict=True,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         Forward pass of LuminaNextDiT.

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 3 additions & 3 deletions

@@ -55,7 +55,7 @@ def __init__(self, hidden_size, frequency_embedding_size=256):
         self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)

-    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
+    def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
         t_emb = self.time_proj(timesteps).to(dtype=wdtype)
         t_emb = self.timestep_embedder(t_emb)
         return t_emb
@@ -87,7 +87,7 @@ def __init__(
         self.out_channels = out_channels
         self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)

-    def forward(self, latent):
+    def forward(self, latent) -> torch.Tensor:
         latent = self.proj(latent)
         return latent

@@ -786,7 +786,7 @@ def forward(
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         encoder_hidden_states = kwargs.get("encoder_hidden_states", None)

         if encoder_hidden_states is not None:

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 2 additions & 2 deletions

@@ -529,7 +529,7 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         *args,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

@@ -1038,7 +1038,7 @@ def forward(
         guidance: torch.Tensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
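The Tuple[torch.Tensor, torch.Tensor] annotation above matches the dual-stream shape of this block: text tokens are concatenated onto the image tokens, attended jointly, and the two streams are returned separately. A simplified sketch of that contract (not the HunyuanVideo source; the attention itself is elided):

# A simplified sketch, not the HunyuanVideo source: a joint-attention block
# returns the image and text streams as a pair, hence the annotation
# -> Tuple[torch.Tensor, torch.Tensor].
from typing import Tuple

import torch


def block_forward(hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    text_seq_length = encoder_hidden_states.shape[1]
    # Concatenate text tokens after the image tokens (mirrors the diff above).
    joint = torch.cat([hidden_states, encoder_hidden_states], dim=1)
    # ... joint attention over `joint` elided ...
    hidden_states = joint[:, :-text_seq_length]
    encoder_hidden_states = joint[:, -text_seq_length:]
    return hidden_states, encoder_hidden_states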

src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py

Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -216,7 +216,7 @@ def forward(
         indices_latents_history_4x: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ):
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
