Skip to content

Commit 0fba4cc

Browse files
committed
wip
1 parent c96bfa5 commit 0fba4cc

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -379,13 +379,18 @@ def custom_forward(*inputs):
379379
return custom_forward
380380

381381
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
382-
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
383-
create_custom_forward(block),
384-
hidden_states,
385-
encoder_hidden_states,
386-
temb,
387-
**ckpt_kwargs,
388-
)
382+
if self.context_embedder is not None:
383+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
384+
create_custom_forward(block),
385+
hidden_states,
386+
encoder_hidden_states,
387+
temb,
388+
**ckpt_kwargs,
389+
)
390+
else:
391+
hidden_states = torch.utils.checkpoint.checkpoint(
392+
create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
393+
)
389394

390395
else:
391396
if self.context_embedder is not None:

0 commit comments

Comments
 (0)