
Commit 15fbf70

Support passing kwargs to the CogVideoX custom attention processor
1 parent fdcbbdf commit 15fbf70

File tree

src/diffusers/models/attention_processor.py
src/diffusers/models/transformers/cogvideox_transformer_3d.py

2 files changed: +9 additions, -0 deletions


src/diffusers/models/attention_processor.py

Lines changed: 4 additions & 0 deletions
@@ -2813,6 +2813,8 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)

@@ -2884,6 +2886,8 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)

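The two hunks above add `*args, **kwargs` to the `__call__` signatures of the stock CogVideoX attention processors, so they stay call-compatible when a transformer block forwards extra keyword arguments that only a custom processor consumes. Below is a minimal sketch of such a custom processor; the class name and the `attn_scale` argument are hypothetical, and the attention math is reduced to a single-head path rather than copied from the library.

```python
import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention


class ScaledCogVideoXAttnProcessor:
    """Hypothetical processor that consumes one extra kwarg (`attn_scale`)."""

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask=None,
        image_rotary_emb=None,
        attn_scale: float = 1.0,  # extra kwarg forwarded from the block via **attention_kwargs
        *args,
        **kwargs,  # tolerate anything else, mirroring the change to the stock processors
    ):
        text_seq_length = encoder_hidden_states.size(1)
        # CogVideoX attends over text and video tokens jointly.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # Single-head sketch: the real processors also reshape to heads, apply QK norms,
        # and apply rotary embeddings to the video tokens before attention.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        hidden_states = F.scaled_dot_product_attention(query, key, value) * attn_scale
        hidden_states = attn.to_out[0](hidden_states)

        # Return in the same (video, text) order as the stock processors.
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
```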
src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 5 additions & 0 deletions
@@ -120,8 +120,10 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}

         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )

         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -497,6 +500,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
+                    attention_kwargs,
                     **ckpt_kwargs,
                 )
             else:
@@ -505,6 +509,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )

         if not self.config.use_rotary_positional_embeddings:
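
With the block-level changes above, whatever dictionary reaches `CogVideoXBlock.forward` as `attention_kwargs` is splatted into the `self.attn1(...)` call and arrives as keyword arguments on whichever processor is installed (how `attention_kwargs` is plumbed into the top-level `CogVideoXTransformer3DModel.forward` is outside this diff). A rough usage sketch follows, reusing the hypothetical `ScaledCogVideoXAttnProcessor` from the previous example; the tensor sizes are arbitrary and the `CogVideoXBlock` constructor arguments are assumed, not shown in this commit.

```python
import torch

from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock

# Tiny block, just to exercise the call path.
block = CogVideoXBlock(dim=64, num_attention_heads=4, attention_head_dim=16, time_embed_dim=32)
block.attn1.set_processor(ScaledCogVideoXAttnProcessor())  # hypothetical processor from the sketch above

hidden_states = torch.randn(1, 16, 64)         # video tokens: [batch, frames * height * width, dim]
encoder_hidden_states = torch.randn(1, 8, 64)  # text tokens: [batch, text_seq_length, dim]
temb = torch.randn(1, 32)                      # timestep embedding: [batch, time_embed_dim]

hidden_states, encoder_hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
    attention_kwargs={"attn_scale": 0.5},  # forwarded to the processor's __call__ unchanged
)
```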
