From 1ea9593779dbbfef3b3d208e0b37b61862a94705 Mon Sep 17 00:00:00 2001 From: pibbo88 Date: Mon, 23 Sep 2024 22:01:04 +0800 Subject: [PATCH] Fix the bug of sd3 controlnet training when using gradient_checkpointing. Refer to issue #9496 --- src/diffusers/models/controlnet_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index f19571dafb18..43b52a645a0d 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -336,7 +336,7 @@ def custom_forward(*inputs): return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states,