6 changes: 5 additions & 1 deletion src/diffusers/models/controlnets/controlnet_sd3.py
@@ -56,6 +56,8 @@ def __init__(
out_channels: int = 16,
pos_embed_max_size: int = 96,
extra_conditioning_channels: int = 0,
+ dual_attention_layers: Tuple[int, ...] = (),
+ qk_norm: Optional[str] = None,
):
super().__init__()
default_out_channels = in_channels
@@ -84,6 +86,8 @@ def __init__(
num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
+ qk_norm=qk_norm,
+ use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(num_layers)
]
@@ -248,7 +252,7 @@ def from_transformer(
config = transformer.config
config["num_layers"] = num_layers or config.num_layers
config["extra_conditioning_channels"] = num_extra_conditioning_channels
- controlnet = cls(**config)
+ controlnet = cls.from_config(config)

if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
4 changes: 3 additions & 1 deletion src/diffusers/models/transformers/transformer_sd3.py
@@ -15,6 +15,7 @@

from typing import Any, Dict, List, Optional, Tuple, Union

+ import numpy as np
import torch
import torch.nn as nn

@@ -344,7 +345,8 @@ def custom_forward(*inputs):

# controlnet residual
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
- interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+ interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
Collaborator

why are we making this change? it is not the same, so a breaking change, no?

Contributor Author

@yiyixuxu Good question.

The revised code adapts the strategy used by ControlNet for Flux, introducing a significant improvement in flexibility. Here's why this change matters:

In the old code, the number of transformer layers must be divisible by the number of ControlNet layers. For example, with SD3.5 Large, which has 38 transformer layers, there were only two valid options for the number of ControlNet layers: 2 and 19. Setting the number of ControlNet layers to anything else, such as 5, would cause the old code to crash.

However, the Flux ControlNet approach removes this restriction, allowing greater flexibility in choosing the number of layers. The revised logic essentially mirrors the Flux implementation, enabling more versatile configurations.

Importantly, the new code maintains compatibility with existing setups. If the number of transformer layers is divisible by the number of ControlNet layers, the interval_control remains unchanged, ensuring all previous configurations continue to function seamlessly.
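
To make that concrete, here is a small illustrative sketch; the layer counts are hypothetical, and only the interval computation mirrors the code above:

```python
import numpy as np

num_transformer_blocks = 38  # e.g. SD3.5 Large (illustrative values)
num_controlnet_blocks = 5    # not a divisor of 38

# Old logic: integer division.
old_interval = num_transformer_blocks // num_controlnet_blocks  # 7
# Block 35 already indexes 35 // 7 == 5, but only indices 0-4 exist in
# block_controlnet_hidden_states -> IndexError.

# New logic: float division rounded up.
new_interval = int(np.ceil(num_transformer_blocks / num_controlnet_blocks))  # 8
# 37 // 8 == 4, the last valid ControlNet index, so any layer count works.
for index_block in range(num_transformer_blocks):
    assert index_block // new_interval < num_controlnet_blocks
```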

Collaborator

thanks! I think it's indeed better, I'm just wondering if it would cause issues for a controlnet trained with the current logic
cc @haofanwang here

Contributor Author
@linjiapro (Nov 20, 2024)

I don't think it will cause any issue for controlnets trained with the old code before this PR.

The reason is that for a controlnet to be trained with the old code, the number of transformer layers has to be divisible by the number of controlnet layers, and the new logic after this PR does not change the behavior in that scenario.
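
A quick illustrative check of that point (the layer counts are assumptions chosen for the example): when the transformer layer count is divisible by the ControlNet layer count, the ceil-based interval reduces to the old integer division, so previously trained controlnets see identical routing.

```python
import numpy as np

num_transformer_blocks = 38
for num_controlnet_blocks in (1, 2, 19, 38):  # divisors of 38
    old_interval = num_transformer_blocks // num_controlnet_blocks
    new_interval = int(np.ceil(num_transformer_blocks / num_controlnet_blocks))
    assert old_interval == new_interval  # identical mapping -> same behavior
```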

+ interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]

hidden_states = self.norm_out(hidden_states, temb)
6 changes: 4 additions & 2 deletions tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -15,6 +15,7 @@

import gc
import unittest
+ from typing import Optional

import numpy as np
import pytest
@@ -59,7 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
)
batch_params = frozenset(["prompt", "negative_prompt"])

- def get_dummy_components(self):
+ def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
@@ -72,14 +73,15 @@ def get_dummy_components(self):
caption_projection_dim=32,
pooled_projection_dim=64,
out_channels=8,
+ qk_norm=qk_norm,
)

torch.manual_seed(0)
controlnet = SD3ControlNetModel(
sample_size=32,
patch_size=1,
in_channels=8,
- num_layers=1,
+ num_layers=num_controlnet_layers,
attention_head_dim=8,
num_attention_heads=4,
joint_attention_dim=32,