2 changes: 1 addition & 1 deletion src/diffusers/models/attention.py
@@ -674,7 +674,7 @@ def forward(
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
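The annotation added here is a two-tensor tuple because joint blocks of this kind update the image stream (`hidden_states`) and the text stream (`encoder_hidden_states`) together and hand both back to the caller. A minimal sketch of that contract follows; `JointBlockSketch` and its layer-norm body are illustrative stand-ins, not the block's actual attention and feed-forward logic:

from typing import Tuple

import torch
import torch.nn as nn


class JointBlockSketch(nn.Module):
    """Toy stand-in for a dual-stream transformer block."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.norm_image = nn.LayerNorm(dim)
        self.norm_text = nn.LayerNorm(dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Both streams are transformed and both are returned, which is
        # why the honest annotation is a 2-tuple, not a single tensor.
        return self.norm_text(encoder_hidden_states), self.norm_image(hidden_states)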
@@ -173,7 +173,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}

@@ -122,7 +122,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)
         attention_kwargs = attention_kwargs or {}

@@ -315,7 +315,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)

         # norm & modulate
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -124,7 +124,7 @@ def forward(
         encoder_mask: torch.Tensor,
         temb: torch.Tensor,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Perform a forward pass through the LuminaNextDiTBlock.

2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_bria.py
@@ -472,7 +472,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_len = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import Dict, Union
+from typing import Dict, Union, Tuple

 import torch
 import torch.nn as nn
@@ -79,7 +79,7 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.size(1)

         # norm & modulate
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_cogview4.py
@@ -494,7 +494,7 @@ def forward(
         ] = None,
         attention_mask: Optional[Dict[str, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Timestep conditioning
         (
             norm_hidden_states,
@@ -534,7 +534,7 @@ def forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         wtype = hidden_states.dtype
         (
             shift_msa_i,
@@ -592,7 +592,7 @@ def forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         return self.block(
             hidden_states=hidden_states,
             hidden_states_masks=hidden_states_masks,
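The `Union` in the second hunk reflects a pass-through: this forward just returns whatever the wrapped `self.block` produces, and the inner block may yield a single tensor or a pair of tensors. A minimal sketch of the pattern, with hypothetical names (`BlockWrapperSketch` is illustrative, not the library's class):

from typing import Tuple, Union

import torch
import torch.nn as nn


class BlockWrapperSketch(nn.Module):
    """Toy wrapper that delegates to an inner block of either kind."""

    def __init__(self, block: nn.Module) -> None:
        super().__init__()
        self.block = block

    def forward(
        self, *args, **kwargs
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # The wrapper adds no logic of its own: a single-stream inner
        # block yields one tensor, a dual-stream one yields a pair, so
        # the return annotation is the Union of the two.
        return self.block(*args, **kwargs)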
@@ -684,7 +684,7 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
