From 86a557dde14a7e03d486eaa426c15b7769a07ad7 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Sun, 1 Dec 2024 21:19:51 +0100
Subject: [PATCH 1/2] add

---
 .../models/controlnets/controlnet_sd3.py            | 14 ++++++++++++++
 .../pipeline_stable_diffusion_3_controlnet.py       | 20 +++++++++++++-------
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 2a5fcf35498e..4f3253d82f3d 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -266,6 +266,20 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
+    # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
+    # we should have handled this in conversion script
+    def _get_pos_embed_from_transformer(self, transformer):
+        pos_embed = PatchEmbed(
+            height=transformer.config.sample_size,
+            width=transformer.config.sample_size,
+            patch_size=transformer.config.patch_size,
+            in_channels=transformer.config.in_channels,
+            embed_dim=transformer.inner_dim,
+            pos_embed_max_size=transformer.config.pos_embed_max_size,
+        )
+        pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True)
+        return pos_embed
+
     @classmethod
     def from_transformer(
         cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index b92dafffc715..a92b53c8770f 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -193,7 +193,19 @@ def __init__(
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
+            for controlnet_model in controlnet:
+                # for SD3.5 8b controlnet, it shares the pos_embed with the transformer
+                if (
+                    hasattr(controlnet_model.config, "use_pos_embed")
+                    and controlnet_model.config.use_pos_embed is False
+                ):
+                    pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
+                    controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
             controlnet = SD3MultiControlNetModel(controlnet)
+        else:
+            if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
+                pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
+                controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)
 
         self.register_modules(
             vae=vae,
@@ -1042,15 +1054,9 @@ def __call__(
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
 
-                if controlnet_config.use_pos_embed is False:
-                    # sd35 (offical) 8b controlnet
-                    controlnet_model_input = self.transformer.pos_embed(latent_model_input)
-                else:
-                    controlnet_model_input = latent_model_input
-
                 # controlnet(s) inference
                 control_block_samples = self.controlnet(
-                    hidden_states=controlnet_model_input,
+                    hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,

From 8098ea20ec08a627080f7e82ff8b42ceccd34e62 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Mon, 2 Dec 2024 20:48:55 +0100
Subject: [PATCH 2/2] update

---
 .../pipeline_stable_diffusion_3_controlnet.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index a92b53c8770f..8fd07fafc766 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -193,7 +193,9 @@ def __init__(
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
-            for controlnet_model in controlnet:
+            controlnet = SD3MultiControlNetModel(controlnet)
+        if isinstance(controlnet, SD3MultiControlNetModel):
+            for controlnet_model in controlnet.nets:
                 # for SD3.5 8b controlnet, it shares the pos_embed with the transformer
                 if (
                     hasattr(controlnet_model.config, "use_pos_embed")
@@ -201,8 +203,7 @@ def __init__(
                 ):
                     pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
                     controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
-            controlnet = SD3MultiControlNetModel(controlnet)
-        else:
+        elif isinstance(controlnet, SD3ControlNetModel):
             if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
                 pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
                 controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)
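
A minimal usage sketch of the behavior these two commits introduce, assuming a controlnet checkpoint whose config sets use_pos_embed=False (the SD3.5 Large 8b case); the repo IDs below are illustrative placeholders, not verified checkpoint names:

# Sketch only: "<sd3.5-large-repo>" and "<sd3.5-large-controlnet-repo>" are
# placeholder IDs for a transformer + 8b controlnet checkpoint pair.
import torch

from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline

controlnet = SD3ControlNetModel.from_pretrained(
    "<sd3.5-large-controlnet-repo>", torch_dtype=torch.float16
)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "<sd3.5-large-repo>", controlnet=controlnet, torch_dtype=torch.float16
)

# With both patches applied, __init__ detects `use_pos_embed is False` and copies
# the transformer's PatchEmbed weights into the controlnet via
# _get_pos_embed_from_transformer, so __call__ can pass latent_model_input to the
# controlnet directly instead of pre-embedding it with self.transformer.pos_embed.
if getattr(pipe.controlnet.config, "use_pos_embed", True) is False:
    assert pipe.controlnet.pos_embed is not None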