diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 2a5fcf35498e..4f3253d82f3d 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -266,6 +266,20 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
+    # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
+    # we should have handled this in conversion script
+    def _get_pos_embed_from_transformer(self, transformer):
+        pos_embed = PatchEmbed(
+            height=transformer.config.sample_size,
+            width=transformer.config.sample_size,
+            patch_size=transformer.config.patch_size,
+            in_channels=transformer.config.in_channels,
+            embed_dim=transformer.inner_dim,
+            pos_embed_max_size=transformer.config.pos_embed_max_size,
+        )
+        pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True)
+        return pos_embed
+
     @classmethod
     def from_transformer(
         cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index b92dafffc715..8fd07fafc766 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -194,6 +194,19 @@ def __init__(
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
             controlnet = SD3MultiControlNetModel(controlnet)
+        if isinstance(controlnet, SD3MultiControlNetModel):
+            for controlnet_model in controlnet.nets:
+                # for SD3.5 8b controlnet, it shares the pos_embed with the transformer
+                if (
+                    hasattr(controlnet_model.config, "use_pos_embed")
+                    and controlnet_model.config.use_pos_embed is False
+                ):
+                    pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
+                    controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
+        elif isinstance(controlnet, SD3ControlNetModel):
+            if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
+                pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
+                controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)
 
         self.register_modules(
             vae=vae,
@@ -1042,15 +1055,9 @@ def __call__(
                     controlnet_cond_scale = controlnet_cond_scale[0]
                 cond_scale = controlnet_cond_scale * controlnet_keep[i]
 
-                if controlnet_config.use_pos_embed is False:
-                    # sd35 (offical) 8b controlnet
-                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
-                else:
-                    controlnet_model_input = latent_model_input
-
                 # controlnet(s) inference
                 control_block_samples = self.controlnet(
-                    hidden_states=controlnet_model_input,
+                    hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,