
Commit 1b61c34

Flux: pass joint_attention_kwargs when gradient_checkpointing
1 parent 00f95b9 commit 1b61c34

File tree

1 file changed: +2 −0 lines changed


src/diffusers/models/transformers/transformer_flux.py

Lines changed: 2 additions & 0 deletions
@@ -485,6 +485,7 @@ def forward(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )
 
             else:
@@ -516,6 +517,7 @@ def forward(
                     hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )
 
             else:
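Why the fix matters, as a minimal sketch: `torch.utils.checkpoint.checkpoint` only forwards the arguments it is explicitly given to the wrapped function, so a trailing kwargs dict that is omitted from the call silently falls back to its default inside the block. The `block_forward` function below is a hypothetical stand-in for the Flux transformer block's forward, not the actual diffusers code; only the checkpointing call pattern mirrors the diff.

```python
import torch
import torch.utils.checkpoint

def block_forward(hidden_states, temb, image_rotary_emb, joint_attention_kwargs=None):
    # Hypothetical block: reads an optional "scale" from joint_attention_kwargs.
    scale = 1.0 if joint_attention_kwargs is None else joint_attention_kwargs.get("scale", 1.0)
    return hidden_states * temb * scale + image_rotary_emb

hidden_states = torch.ones(2, 4, requires_grad=True)
temb = torch.full((2, 4), 2.0)
image_rotary_emb = torch.zeros(2, 4)

# Before the fix: joint_attention_kwargs is never passed through the
# checkpoint call, so the block sees None and uses its default scale.
out_default = torch.utils.checkpoint.checkpoint(
    block_forward, hidden_states, temb, image_rotary_emb, use_reentrant=False
)

# After the fix: joint_attention_kwargs is forwarded as one more positional
# argument, exactly as the diff adds it to the argument list.
out_fixed = torch.utils.checkpoint.checkpoint(
    block_forward, hidden_states, temb, image_rotary_emb, {"scale": 0.5},
    use_reentrant=False,
)
```

With gradient checkpointing disabled the block is called directly and already receives the kwargs, which is why the bug only surfaced on the checkpointed path.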
