From a0d199a67a1de636a4a824074b024c433b2394d4 Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Wed, 23 Oct 2024 15:29:21 -0700 Subject: [PATCH 01/15] improve control net index --- src/diffusers/models/transformers/transformer_sd3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index b28350b8ed9c..5816fada97d3 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -345,7 +345,10 @@ def custom_forward(*inputs): # controlnet residual if block_controlnet_hidden_states is not None and block.context_pre_only is False: interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) - hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] + hidden_states_layer_index = index_block // interval_control + if hidden_states_layer_index >= len(block_controlnet_hidden_states): + hidden_states_layer_index = len(block_controlnet_hidden_states) - 1 + hidden_states = hidden_states + block_controlnet_hidden_states[hidden_states_layer_index] hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From 48b4b624ca4341e8355a4babb631a07d96e70b53 Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Wed, 23 Oct 2024 23:31:48 -0700 Subject: [PATCH 02/15] wip --- .gitignore | 1 + tests/pipelines/controlnet_sd3/test_controlnet_sd3.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 15617d5fdc74..f90cf4d4091e 100644 --- a/.gitignore +++ b/.gitignore @@ -102,6 +102,7 @@ venv/ ENV/ env.bak/ venv.bak/ +myenv/ # Spyder project settings .spyderproject diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 74cb56e0337a..e73a61d5fc95 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -77,7 +77,7 @@ def get_dummy_components(self): sample_size=32, patch_size=1, in_channels=8, - num_layers=1, + num_layers=3, attention_head_dim=8, num_attention_heads=4, joint_attention_dim=32, From d1a1ebef3f2dbfd4717c99e1d54c221d09648d3a Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Thu, 24 Oct 2024 12:58:04 -0700 Subject: [PATCH 03/15] wip --- src/diffusers/models/controlnet_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index 43b52a645a0d..3c70ea133f8c 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -248,7 +248,7 @@ def from_transformer( config = transformer.config config["num_layers"] = num_layers or config.num_layers config["extra_conditioning_channels"] = num_extra_conditioning_channels - controlnet = cls(**config) + controlnet = cls.from_config(**config) if load_weights_from_transformer: controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) From bd32b2b18da0d48075beb4599ba2a6fd7ad0002b Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Thu, 24 Oct 2024 13:01:00 -0700 Subject: [PATCH 04/15] wip --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index f90cf4d4091e..15617d5fdc74 100644 --- a/.gitignore +++ b/.gitignore @@ -102,7 +102,6 @@ venv/ ENV/ env.bak/ venv.bak/ -myenv/ # Spyder project settings .spyderproject From a7ffec7f482b37074da1251c7282c3635b25f20e Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Thu, 24 Oct 2024 13:07:31 -0700 Subject: [PATCH 05/15] wip --- tests/pipelines/controlnet_sd3/test_controlnet_sd3.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index e73a61d5fc95..e3ee51f6c692 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -56,7 +56,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ] ) batch_params = frozenset(["prompt", "negative_prompt"]) - + def get_dummy_components(self): torch.manual_seed(0) transformer = SD3Transformer2DModel( @@ -73,11 +73,12 @@ def get_dummy_components(self): ) torch.manual_seed(0) + num_controlnet_layers = 3 controlnet = SD3ControlNetModel( sample_size=32, patch_size=1, in_channels=8, - num_layers=3, + num_layers=num_controlnet_layers, attention_head_dim=8, num_attention_heads=4, joint_attention_dim=32, From 235b800e53e509b81efdc4e6586527e4be1a0eab Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Thu, 24 Oct 2024 13:20:44 -0700 Subject: [PATCH 06/15] wip --- tests/pipelines/controlnet_sd3/test_controlnet_sd3.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index e3ee51f6c692..cd8f06947a85 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -57,7 +57,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ) batch_params = frozenset(["prompt", "negative_prompt"]) - def get_dummy_components(self): + def get_dummy_components(self, num_controlnet_layers: int = 3): torch.manual_seed(0) transformer = SD3Transformer2DModel( sample_size=32, @@ -73,7 +73,6 @@ def get_dummy_components(self): ) torch.manual_seed(0) - num_controlnet_layers = 3 controlnet = SD3ControlNetModel( sample_size=32, patch_size=1, From 50b4db91e925466b767f47eb40eb9db2b64e62d1 Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Fri, 25 Oct 2024 22:59:02 -0700 Subject: [PATCH 07/15] wip --- src/diffusers/models/controlnet_sd3.py | 2 ++ tests/pipelines/controlnet_sd3/test_controlnet_sd3.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index 3c70ea133f8c..0799ffe1b243 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -56,6 +56,7 @@ def __init__( out_channels: int = 16, pos_embed_max_size: int = 96, extra_conditioning_channels: int = 0, + qk_norm: Optional[str] = None, ): super().__init__() default_out_channels = in_channels @@ -84,6 +85,7 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=self.config.attention_head_dim, context_pre_only=False, + qk_norm=qk_norm, ) for i in range(num_layers) ] diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index cd8f06947a85..fc3d4e3395b9 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -18,6 +18,7 @@ import numpy as np import torch +from typing import Optional from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from diffusers import ( @@ -57,7 +58,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ) batch_params = frozenset(["prompt", "negative_prompt"]) - def get_dummy_components(self, num_controlnet_layers: int = 3): + def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = None): torch.manual_seed(0) transformer = SD3Transformer2DModel( sample_size=32, @@ -70,6 +71,7 @@ def get_dummy_components(self, num_controlnet_layers: int = 3): caption_projection_dim=32, pooled_projection_dim=64, out_channels=8, + qk_norm=qk_norm, ) torch.manual_seed(0) From 933ecf379f6b85b98172221b6bbd9395dceb6897 Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Fri, 25 Oct 2024 23:09:13 -0700 Subject: [PATCH 08/15] add rms_norm to controlnet test --- tests/pipelines/controlnet_sd3/test_controlnet_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index fc3d4e3395b9..e6a418661995 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -58,7 +58,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ) batch_params = frozenset(["prompt", "negative_prompt"]) - def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = None): + def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"): torch.manual_seed(0) transformer = SD3Transformer2DModel( sample_size=32, From 9d33417f56f92bac02e70612a5ba6ed9c5963093 Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Tue, 29 Oct 2024 11:46:51 -0700 Subject: [PATCH 09/15] wip --- src/diffusers/models/controlnet_sd3.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index 0799ffe1b243..e9b6a8192e11 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -56,6 +56,9 @@ def __init__( out_channels: int = 16, pos_embed_max_size: int = 96, extra_conditioning_channels: int = 0, + dual_attention_layers: Tuple[ + int, ... + ] = (), qk_norm: Optional[str] = None, ): super().__init__() @@ -86,6 +89,7 @@ def __init__( attention_head_dim=self.config.attention_head_dim, context_pre_only=False, qk_norm=qk_norm, + use_dual_attention=True if i in dual_attention_layers else False, ) for i in range(num_layers) ] From cd3069c619ac809386ca17989376fbf98e82210e Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Tue, 29 Oct 2024 11:53:09 -0700 Subject: [PATCH 10/15] wip --- src/diffusers/models/controlnet_sd3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index e9b6a8192e11..8b0433b76ddc 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -56,6 +56,7 @@ def __init__( out_channels: int = 16, pos_embed_max_size: int = 96, extra_conditioning_channels: int = 0, + context_pre_only_last_layer: bool = False, dual_attention_layers: Tuple[ int, ... ] = (), @@ -87,7 +88,7 @@ def __init__( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=self.config.attention_head_dim, - context_pre_only=False, + context_pre_only=context_pre_only_last_layer and i == num_layers - 1, qk_norm=qk_norm, use_dual_attention=True if i in dual_attention_layers else False, ) From e40bd61d1f3b374b8520e8604e601ce189981081 Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Fri, 15 Nov 2024 00:53:01 -0800 Subject: [PATCH 11/15] wip --- src/diffusers/models/transformers/transformer_sd3.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 5816fada97d3..1cb45463571d 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn +import numpy as np from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -344,11 +345,9 @@ def custom_forward(*inputs): # controlnet residual if block_controlnet_hidden_states is not None and block.context_pre_only is False: - interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) - hidden_states_layer_index = index_block // interval_control - if hidden_states_layer_index >= len(block_controlnet_hidden_states): - hidden_states_layer_index = len(block_controlnet_hidden_states) - 1 - hidden_states = hidden_states + block_controlnet_hidden_states[hidden_states_layer_index] + interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states) + interval_control = int(np.ceil(interval_control)) + hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From d8006c3c1b2b29b4b351045bf219f325fe45c8c7 Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Fri, 15 Nov 2024 01:08:22 -0800 Subject: [PATCH 12/15] format --- .../models/controlnets/controlnet_sd3.py | 123 +++++------------- 1 file changed, 33 insertions(+), 90 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index 6cf5dabf9403..a01e9c038240 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -21,19 +21,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import ( - USE_PEFT_BACKEND, - is_torch_version, - logging, - scale_lora_layers, - unscale_lora_layers, -) +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention import JointTransformerBlock -from ..attention_processor import ( - Attention, - AttentionProcessor, - FusedJointAttnProcessor2_0, -) +from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -48,9 +38,7 @@ class SD3ControlNetOutput(BaseOutput): controlnet_block_samples: Tuple[torch.Tensor] -class SD3ControlNetModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin -): +class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True @register_to_config @@ -68,14 +56,14 @@ def __init__( out_channels: int = 16, pos_embed_max_size: int = 96, extra_conditioning_channels: int = 0, - dual_attention_layers: Tuple[int, ...] = (), + dual_attention_layers: Tuple[ + int, ... + ] = (), qk_norm: Optional[str] = None, ): super().__init__() default_out_channels = in_channels - self.out_channels = ( - out_channels if out_channels is not None else default_out_channels - ) + self.out_channels = out_channels if out_channels is not None else default_out_channels self.inner_dim = num_attention_heads * attention_head_dim self.pos_embed = PatchEmbed( @@ -126,9 +114,7 @@ def __init__( self.gradient_checkpointing = False # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking( - self, chunk_size: Optional[int] = None, dim: int = 0 - ) -> None: + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). @@ -147,9 +133,7 @@ def enable_forward_chunking( # By default chunk size is 1 chunk_size = chunk_size or 1 - def fn_recursive_feed_forward( - module: torch.nn.Module, chunk_size: int, dim: int - ): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) @@ -170,11 +154,7 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: # set recursively processors = {} - def fn_recursive_add_processors( - name: str, - module: torch.nn.Module, - processors: Dict[str, AttentionProcessor], - ): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() @@ -189,9 +169,7 @@ def fn_recursive_add_processors( return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -241,9 +219,7 @@ def fuse_qkv_projections(self): for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): - raise ValueError( - "`fuse_qkv_projections()` is not supported for models having added KV projections." - ) + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors @@ -273,28 +249,18 @@ def _set_gradient_checkpointing(self, module, value=False): @classmethod def from_transformer( - cls, - transformer, - num_layers=12, - num_extra_conditioning_channels=1, - load_weights_from_transformer=True, + cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True ): config = transformer.config config["num_layers"] = num_layers or config.num_layers config["extra_conditioning_channels"] = num_extra_conditioning_channels - controlnet = cls.from_config(**config) + controlnet = cls(**config) if load_weights_from_transformer: controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) - controlnet.time_text_embed.load_state_dict( - transformer.time_text_embed.state_dict() - ) - controlnet.context_embedder.load_state_dict( - transformer.context_embedder.state_dict() - ) - controlnet.transformer_blocks.load_state_dict( - transformer.transformer_blocks.state_dict(), strict=False - ) + controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) + controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) + controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input) @@ -349,17 +315,12 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if ( - joint_attention_kwargs is not None - and joint_attention_kwargs.get("scale", None) is not None - ): + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) - hidden_states = self.pos_embed( - hidden_states - ) # takes care of adding positional embeddings too. + hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) @@ -380,41 +341,29 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = ( - {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - ) - encoder_hidden_states, hidden_states = ( - torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - **ckpt_kwargs, - ) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, ) else: encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ) block_res_samples = block_res_samples + (hidden_states,) controlnet_block_res_samples = () - for block_res_sample, controlnet_block in zip( - block_res_samples, self.controlnet_blocks - ): + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): block_res_sample = controlnet_block(block_res_sample) - controlnet_block_res_samples = controlnet_block_res_samples + ( - block_res_sample, - ) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) # 6. scaling - controlnet_block_res_samples = [ - sample * conditioning_scale for sample in controlnet_block_res_samples - ] + controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer @@ -423,9 +372,7 @@ def custom_forward(*inputs): if not return_dict: return (controlnet_block_res_samples,) - return SD3ControlNetOutput( - controlnet_block_samples=controlnet_block_res_samples - ) + return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) class SD3MultiControlNetModel(ModelMixin): @@ -456,9 +403,7 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[SD3ControlNetOutput, Tuple]: - for i, (image, scale, controlnet) in enumerate( - zip(controlnet_cond, conditioning_scale, self.nets) - ): + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): block_samples = controlnet( hidden_states=hidden_states, timestep=timestep, @@ -476,9 +421,7 @@ def forward( else: control_block_samples = [ control_block_sample + block_sample - for control_block_sample, block_sample in zip( - control_block_samples[0], block_samples[0] - ) + for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) ] control_block_samples = (tuple(control_block_samples),) From b4983cb2ca4573d70d2d49ba6ba324288b26a52d Mon Sep 17 00:00:00 2001 From: Lin Jia Date: Fri, 15 Nov 2024 01:09:32 -0800 Subject: [PATCH 13/15] wip --- src/diffusers/models/controlnet_sd3.py | 2 -- src/diffusers/models/controlnets/controlnet_sd3.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py index b865ec0906da..5e70559e9ac4 100644 --- a/src/diffusers/models/controlnet_sd3.py +++ b/src/diffusers/models/controlnet_sd3.py @@ -34,10 +34,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - class SD3MultiControlNetModel(SD3MultiControlNetModel): def __init__(self, *args, **kwargs): deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead." deprecate("SD3MultiControlNetModel", "0.34", deprecation_message) super().__init__(*args, **kwargs) - diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index a01e9c038240..f213e8153f68 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -254,7 +254,7 @@ def from_transformer( config = transformer.config config["num_layers"] = num_layers or config.num_layers config["extra_conditioning_channels"] = num_extra_conditioning_channels - controlnet = cls(**config) + controlnet = cls.from_config(**config) if load_weights_from_transformer: controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) From 5f560ca0db2dacb23ca84b769a3f2fd4003d5400 Mon Sep 17 00:00:00 2001 From: linjiapro Date: Tue, 19 Nov 2024 23:06:29 -0800 Subject: [PATCH 14/15] Update src/diffusers/models/controlnets/controlnet_sd3.py Co-authored-by: YiYi Xu --- src/diffusers/models/controlnets/controlnet_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index f213e8153f68..58a4ae7ad910 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -254,7 +254,7 @@ def from_transformer( config = transformer.config config["num_layers"] = num_layers or config.num_layers config["extra_conditioning_channels"] = num_extra_conditioning_channels - controlnet = cls.from_config(**config) + controlnet = cls.from_config(config) if load_weights_from_transformer: controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) From 570558b0e7f329fa8b2fc820fff9957c7ea2b46b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 20 Nov 2024 23:32:24 +0100 Subject: [PATCH 15/15] style --- src/diffusers/models/controlnets/controlnet_sd3.py | 4 +--- src/diffusers/models/transformers/transformer_sd3.py | 2 +- tests/pipelines/controlnet_sd3/test_controlnet_sd3.py | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py index 58a4ae7ad910..118e8630ec8e 100644 --- a/src/diffusers/models/controlnets/controlnet_sd3.py +++ b/src/diffusers/models/controlnets/controlnet_sd3.py @@ -56,9 +56,7 @@ def __init__( out_channels: int = 16, pos_embed_max_size: int = 96, extra_conditioning_channels: int = 0, - dual_attention_layers: Tuple[ - int, ... - ] = (), + dual_attention_layers: Tuple[int, ...] = (), qk_norm: Optional[str] = None, ): super().__init__() diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index e6feff05e53a..1d3df99197bb 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -15,9 +15,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn -import numpy as np from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index a9142b673688..90c253f783c6 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -15,11 +15,11 @@ import gc import unittest +from typing import Optional import numpy as np import pytest import torch -from typing import Optional from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from diffusers import ( @@ -59,7 +59,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ] ) batch_params = frozenset(["prompt", "negative_prompt"]) - + def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"): torch.manual_seed(0) transformer = SD3Transformer2DModel(