diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 11b71b1cbece..75ffc4f5c12a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -747,10 +747,15 @@ def __call__( width_control_image, ) - # set control mode - if control_mode is not None: - control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + # Here we ensure that `control_mode` has the same length as the control_image. + if not isinstance(control_mode, list): + control_mode = [control_mode] + if len(control_mode) > 1: + raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or a list containing 1 `int`") + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.view(-1,1).expand(control_image.shape[0], 1) + + #print(f"control image shape {control_image.shape}, mode shape {control_mode.shape}") elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] @@ -785,16 +790,14 @@ def __call__( control_image = control_images - # set control mode - control_mode_ = [] - if isinstance(control_mode, list): - for cmode in control_mode: - if cmode is None: - control_mode_.append(-1) - else: - control_mode_.append(cmode) - control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + # Here we ensure that `control_mode` has the same length as the control_image. + if not isinstance(control_mode, list) or len(control_mode) != len(control_image): + raise ValueError("For Multi-ControlNet, `control_mode` must be a list of the same " + + " length as the number of controlnets (control images) specified") + control_mode = torch.tensor( + [-1 if elem is None else elem for elem in control_mode] + ) + control_mode = control_mode.view(-1,1).expand(len(control_image), 1) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4