Skip to content

Commit 3e9a28a

Browse files
yiyixuxu and Anghellia authored
(authored by @Anghellia) Add support of Xlabs Controlnets huggingface#9638 (huggingface#9687)
* Add support of Xlabs Controlnets --------- Co-authored-by: Anzhella Pankratova <[email protected]>
1 parent 2ffbb88 commit 3e9a28a

File tree

3 files changed

+62
-32
lines changed

3 files changed

+62
-32
lines changed

src/diffusers/models/controlnet_flux.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ..models.attention_processor import AttentionProcessor
2424
from ..models.modeling_utils import ModelMixin
2525
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
26-
from .controlnet import BaseOutput, zero_module
26+
from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
2727
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
2828
from .modeling_outputs import Transformer2DModelOutput
2929
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
@@ -55,6 +55,7 @@ def __init__(
5555
guidance_embeds: bool = False,
5656
axes_dims_rope: List[int] = [16, 56, 56],
5757
num_mode: int = None,
58+
conditioning_embedding_channels: int = None,
5859
):
5960
super().__init__()
6061
self.out_channels = in_channels
@@ -106,7 +107,14 @@ def __init__(
106107
if self.union:
107108
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
108109

109-
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
110+
if conditioning_embedding_channels is not None:
111+
self.input_hint_block = ControlNetConditioningEmbedding(
112+
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
113+
)
114+
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
115+
else:
116+
self.input_hint_block = None
117+
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
110118

111119
self.gradient_checkpointing = False
112120

@@ -269,6 +277,16 @@ def forward(
269277
)
270278
hidden_states = self.x_embedder(hidden_states)
271279

280+
if self.input_hint_block is not None:
281+
controlnet_cond = self.input_hint_block(controlnet_cond)
282+
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
283+
height = height_pw // self.config.patch_size
284+
width = width_pw // self.config.patch_size
285+
controlnet_cond = controlnet_cond.reshape(
286+
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
287+
)
288+
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
289+
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
272290
# add
273291
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
274292

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ def forward(
402402
controlnet_block_samples=None,
403403
controlnet_single_block_samples=None,
404404
return_dict: bool = True,
405+
controlnet_blocks_repeat: bool = False,
405406
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
406407
"""
407408
The [`FluxTransformer2DModel`] forward method.
@@ -508,7 +509,13 @@ def custom_forward(*inputs):
508509
if controlnet_block_samples is not None:
509510
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
510511
interval_control = int(np.ceil(interval_control))
511-
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
512+
# For Xlabs ControlNet.
513+
if controlnet_blocks_repeat:
514+
hidden_states = (
515+
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
516+
)
517+
else:
518+
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
512519

513520
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
514521

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -754,19 +754,22 @@ def __call__(
754754
)
755755
height, width = control_image.shape[-2:]
756756

757-
# vae encode
758-
control_image = self.vae.encode(control_image).latent_dist.sample()
759-
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
760-
761-
# pack
762-
height_control_image, width_control_image = control_image.shape[2:]
763-
control_image = self._pack_latents(
764-
control_image,
765-
batch_size * num_images_per_prompt,
766-
num_channels_latents,
767-
height_control_image,
768-
width_control_image,
769-
)
757+
# The XLabs ControlNet has an input_hint_block; the InstantX ControlNet does not.
758+
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
759+
if self.controlnet.input_hint_block is None:
760+
# vae encode
761+
control_image = self.vae.encode(control_image).latent_dist.sample()
762+
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
763+
764+
# pack
765+
height_control_image, width_control_image = control_image.shape[2:]
766+
control_image = self._pack_latents(
767+
control_image,
768+
batch_size * num_images_per_prompt,
769+
num_channels_latents,
770+
height_control_image,
771+
width_control_image,
772+
)
770773

771774
# Here we ensure that `control_mode` has the same length as the control_image.
772775
if control_mode is not None:
@@ -777,8 +780,9 @@ def __call__(
777780

778781
elif isinstance(self.controlnet, FluxMultiControlNetModel):
779782
control_images = []
780-
781-
for control_image_ in control_image:
783+
# The XLabs ControlNet has an input_hint_block; the InstantX ControlNet does not.
784+
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
785+
for i, control_image_ in enumerate(control_image):
782786
control_image_ = self.prepare_image(
783787
image=control_image_,
784788
width=width,
@@ -790,20 +794,20 @@ def __call__(
790794
)
791795
height, width = control_image_.shape[-2:]
792796

793-
# vae encode
794-
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
795-
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
796-
797-
# pack
798-
height_control_image, width_control_image = control_image_.shape[2:]
799-
control_image_ = self._pack_latents(
800-
control_image_,
801-
batch_size * num_images_per_prompt,
802-
num_channels_latents,
803-
height_control_image,
804-
width_control_image,
805-
)
806-
797+
if self.controlnet.nets[0].input_hint_block is None:
798+
# vae encode
799+
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
800+
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
801+
802+
# pack
803+
height_control_image, width_control_image = control_image_.shape[2:]
804+
control_image_ = self._pack_latents(
805+
control_image_,
806+
batch_size * num_images_per_prompt,
807+
num_channels_latents,
808+
height_control_image,
809+
width_control_image,
810+
)
807811
control_images.append(control_image_)
808812

809813
control_image = control_images
@@ -927,6 +931,7 @@ def __call__(
927931
img_ids=latent_image_ids,
928932
joint_attention_kwargs=self.joint_attention_kwargs,
929933
return_dict=False,
934+
controlnet_blocks_repeat=controlnet_blocks_repeat,
930935
)[0]
931936

932937
# compute the previous noisy sample x_t -> x_t-1

0 commit comments

Comments
 (0)