Skip to content

Commit da475ec

Browse files
committed
Merge branch 'main' into ltx-integration
2 parents f18cf1a + cd34439 commit da475ec

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,20 @@ def _set_gradient_checkpointing(self, module, value=False):
266266
if hasattr(module, "gradient_checkpointing"):
267267
module.gradient_checkpointing = value
268268

269+
# Note: This is for the SD3.5 8b controlnet, which shares the pos_embed with the transformer
270+
# we should have handled this in the conversion script
271+
def _get_pos_embed_from_transformer(self, transformer):
272+
pos_embed = PatchEmbed(
273+
height=transformer.config.sample_size,
274+
width=transformer.config.sample_size,
275+
patch_size=transformer.config.patch_size,
276+
in_channels=transformer.config.in_channels,
277+
embed_dim=transformer.inner_dim,
278+
pos_embed_max_size=transformer.config.pos_embed_max_size,
279+
)
280+
pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True)
281+
return pos_embed
282+
269283
@classmethod
270284
def from_transformer(
271285
cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,19 @@ def __init__(
194194
super().__init__()
195195
if isinstance(controlnet, (list, tuple)):
196196
controlnet = SD3MultiControlNetModel(controlnet)
197+
if isinstance(controlnet, SD3MultiControlNetModel):
198+
for controlnet_model in controlnet.nets:
199+
# for SD3.5 8b controlnet, it shares the pos_embed with the transformer
200+
if (
201+
hasattr(controlnet_model.config, "use_pos_embed")
202+
and controlnet_model.config.use_pos_embed is False
203+
):
204+
pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
205+
controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
206+
elif isinstance(controlnet, SD3ControlNetModel):
207+
if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
208+
pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
209+
controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)
197210

198211
self.register_modules(
199212
vae=vae,
@@ -1042,15 +1055,9 @@ def __call__(
10421055
controlnet_cond_scale = controlnet_cond_scale[0]
10431056
cond_scale = controlnet_cond_scale * controlnet_keep[i]
10441057

1045-
if controlnet_config.use_pos_embed is False:
1046-
# sd35 (official) 8b controlnet
1047-
controlnet_model_input = self.transformer.pos_embed(latent_model_input)
1048-
else:
1049-
controlnet_model_input = latent_model_input
1050-
10511058
# controlnet(s) inference
10521059
control_block_samples = self.controlnet(
1053-
hidden_states=controlnet_model_input,
1060+
hidden_states=latent_model_input,
10541061
timestep=timestep,
10551062
encoder_hidden_states=controlnet_encoder_hidden_states,
10561063
pooled_projections=controlnet_pooled_projections,

0 commit comments

Comments
 (0)