Skip to content

Commit 3376252

Browse files
RefractAIsayakpaulyiyixuxu
authored
Fix gradient checkpointing issue for Stable Diffusion 3 (#8542)
Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent 16170c6 commit 3376252

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def custom_forward(*inputs):
306306
return custom_forward
307307

308308
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
309-
hidden_states = torch.utils.checkpoint.checkpoint(
309+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
310310
create_custom_forward(block),
311311
hidden_states,
312312
encoder_hidden_states,

0 commit comments

Comments
 (0)