4 changes: 4 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -2813,6 +2813,8 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)

@@ -2884,6 +2886,8 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)

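Note: the two hunks above only widen the processors' call signatures; extra positional and keyword arguments are accepted and silently ignored. A minimal standalone sketch of the pattern (SketchProcessor and its placeholder body are illustrative, not the actual diffusers processors, which compute the full attention here):

```python
# Illustrative only: shows why adding *args/**kwargs lets callers forward
# arbitrary attention kwargs without the processor raising a TypeError.
from typing import Optional

import torch


class SketchProcessor:
    def __call__(
        self,
        attn,  # the Attention module; unused in this sketch
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        *args,      # extra positionals: accepted, ignored
        **kwargs,   # extra keywords (e.g. entries of attention_kwargs): accepted, ignored
    ) -> torch.Tensor:
        # A real processor would compute attention here; return the input as a stand-in.
        return hidden_states


proc = SketchProcessor()
# Without *args/**kwargs this call would fail with
# "TypeError: __call__() got an unexpected keyword argument 'some_future_kwarg'".
out = proc(None, torch.zeros(1, 4, 8), torch.zeros(1, 2, 8), some_future_kwarg=1.0)
```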
5 changes: 5 additions & 0 deletions src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -120,8 +120,10 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}

         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )

         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -497,6 +500,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
+                    attention_kwargs,
                     **ckpt_kwargs,
                 )
             else:
@@ -505,6 +509,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )

         if not self.config.use_rotary_positional_embeddings:
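Taken together, the block now accepts an attention_kwargs dict and forwards it verbatim into self.attn1, and the transformer forward passes it to every block in both the gradient-checkpointing and the plain path. A hedged end-to-end sketch of that plumbing (TinyAttention/TinyBlock are illustrative toy classes, not the diffusers CogVideoX classes):

```python
# Minimal standalone sketch of the attention_kwargs plumbing added in this PR.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class TinyAttention(nn.Module):
    def forward(self, hidden_states, encoder_hidden_states, image_rotary_emb=None, **kwargs):
        # kwargs would carry per-call options from attention_kwargs; ignored in this sketch.
        return hidden_states, encoder_hidden_states


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = TinyAttention()

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,  # unused in this sketch, kept to mirror the real signature
        image_rotary_emb=None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        attention_kwargs = attention_kwargs or {}
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **attention_kwargs,  # forwarded verbatim, exactly as in the diff above
        )
        # Placeholder combine step; the real block applies gating and feed-forward here.
        return attn_hidden_states + attn_encoder_hidden_states.mean()


block = TinyBlock()
h = torch.zeros(1, 16, 8)
e = torch.zeros(1, 4, 8)
# The dict reaches TinyAttention.forward via **kwargs without any signature changes elsewhere.
out = block(h, e, temb=torch.zeros(1, 8), attention_kwargs={"some_option": 1.0})
```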