diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 209aad93244e..118e8630ec8e 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -56,6 +56,8 @@ def __init__(
         out_channels: int = 16,
         pos_embed_max_size: int = 96,
         extra_conditioning_channels: int = 0,
+        dual_attention_layers: Tuple[int, ...] = (),
+        qk_norm: Optional[str] = None,
     ):
         super().__init__()
         default_out_channels = in_channels
@@ -84,6 +86,8 @@ def __init__(
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=False,
+                    qk_norm=qk_norm,
+                    use_dual_attention=True if i in dual_attention_layers else False,
                 )
                 for i in range(num_layers)
             ]
@@ -248,7 +252,7 @@ def from_transformer(
         config = transformer.config
         config["num_layers"] = num_layers or config.num_layers
         config["extra_conditioning_channels"] = num_extra_conditioning_channels
-        controlnet = cls(**config)
+        controlnet = cls.from_config(config)
 
         if load_weights_from_transformer:
             controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index f39a102c7256..1d3df99197bb 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -15,6 +15,7 @@
 
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -344,7 +345,8 @@ def custom_forward(*inputs):
 
             # controlnet residual
             if block_controlnet_hidden_states is not None and block.context_pre_only is False:
-                interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
+                interval_control = int(np.ceil(interval_control))
                 hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
 
         hidden_states = self.norm_out(hidden_states, temb)
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index aae1dc0ebcb0..90c253f783c6 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -15,6 +15,7 @@
 
 import gc
 import unittest
+from typing import Optional
 
 import numpy as np
 import pytest
@@ -59,7 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
     )
     batch_params = frozenset(["prompt", "negative_prompt"])
 
-    def get_dummy_components(self):
+    def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"):
         torch.manual_seed(0)
         transformer = SD3Transformer2DModel(
             sample_size=32,
@@ -72,6 +73,7 @@ def get_dummy_components(self):
             caption_projection_dim=32,
             pooled_projection_dim=64,
             out_channels=8,
+            qk_norm=qk_norm,
         )
 
         torch.manual_seed(0)
@@ -79,7 +81,7 @@ def get_dummy_components(self):
             sample_size=32,
             patch_size=1,
             in_channels=8,
-            num_layers=1,
+            num_layers=num_controlnet_layers,
             attention_head_dim=8,
             num_attention_heads=4,
             joint_attention_dim=32,