Skip to content

Commit b07b48d

Browse files
committed
Add controlnet_blocks_repeat to Flux forward
1 parent 7be937e commit b07b48d

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ def forward(
402402
controlnet_block_samples=None,
403403
controlnet_single_block_samples=None,
404404
return_dict: bool = True,
405+
controlnet_blocks_repeat: bool = False,
405406
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
406407
"""
407408
The [`FluxTransformer2DModel`] forward method.
@@ -509,8 +510,8 @@ def custom_forward(*inputs):
509510
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
510511
interval_control = int(np.ceil(interval_control))
511512
# For Xlabs ControlNet.
512-
if len(controlnet_block_samples) == 2:
513-
hidden_states = hidden_states + controlnet_block_samples[index_block % 2]
513+
if controlnet_blocks_repeat:
514+
hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
514515
else:
515516
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
516517

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,7 @@ def __call__(
739739
)
740740

741741
# 3. Prepare control image
742+
controlnet_blocks_repeat = False
742743
num_channels_latents = self.transformer.config.in_channels // 4
743744
if isinstance(self.controlnet, FluxControlNetModel):
744745
control_image = self.prepare_image(
@@ -766,6 +767,8 @@ def __call__(
766767
height_control_image,
767768
width_control_image,
768769
)
770+
else:
771+
controlnet_blocks_repeat = True
769772

770773
# Here we ensure that `control_mode` has the same length as the control_image.
771774
if control_mode is not None:
@@ -926,6 +929,7 @@ def __call__(
926929
img_ids=latent_image_ids,
927930
joint_attention_kwargs=self.joint_attention_kwargs,
928931
return_dict=False,
932+
controlnet_blocks_repeat=controlnet_blocks_repeat,
929933
)[0]
930934

931935
# compute the previous noisy sample x_t -> x_t-1

0 commit comments

Comments
 (0)