From 370f3828c24b7bbb48e861b3589171c168c2b5a2 Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Tue, 10 Sep 2024 15:20:20 +0000 Subject: [PATCH 1/4] flux controlnet mode to take into account batch size --- .../pipelines/flux/pipeline_flux_controlnet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 481994903d3f..4bbbcf821fe9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -748,9 +748,11 @@ def __call__( ) # set control mode + orig_mode_type = type(control_mode) if control_mode is not None: - control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1,1) + if orig_mode_type == int: + control_mode = control_mode.repeat(control_image.shape[0], 1) elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] @@ -793,8 +795,10 @@ def __call__( control_mode_.append(-1) else: control_mode_.append(cmode) - control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) - control_mode = control_mode.reshape([-1, 1]) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.view(-1, 1) + else: + raise ValueError("For multi-controlnet, control_mode should be a list") # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 From c333d89b77760035cbd72dba939480c4bb5fc578 Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Wed, 18 Sep 2024 20:19:18 +0000 Subject: [PATCH 2/4] incorporate yiyixuxu's suggestions (cleaner logic) as well as clean up control mode handling for multi case --- .../flux/pipeline_flux_controlnet.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 4bbbcf821fe9..285402238774 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -747,12 +747,15 @@ def __call__( width_control_image, ) - # set control mode - orig_mode_type = type(control_mode) - if control_mode is not None: - control_mode = torch.tensor(control_mode).to(device, dtype=torch.long).view(-1,1) - if orig_mode_type == int: - control_mode = control_mode.repeat(control_image.shape[0], 1) + # Here we ensure that `control_mode` has the same length as the control_image. + if not isinstance(control_mode, list): + control_mode = [control_mode] + if len(control_mode) > 1: + raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or a list containing 1 `int`") + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.view(-1,1).expand(control_image.shape[0], 1) + + #print(f"control image shape {control_image.shape}, mode shape {control_mode.shape}") elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] @@ -787,18 +790,14 @@ def __call__( control_image = control_images - # set control mode - control_mode_ = [] - if isinstance(control_mode, list): - for cmode in control_mode: - if cmode is None: - control_mode_.append(-1) - else: - control_mode_.append(cmode) - control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) - control_mode = control_mode.view(-1, 1) - else: - raise ValueError("For multi-controlnet, control_mode should be a list") + # Here we ensure that `control_mode` has the same length as the control_image. + if not isinstance(control_mode, list) or len(control_mode) != len(control_image): + raise ValueError("For Multi-ControlNet, `control_mode` must be a list of the same " + + " length as the number of controlnets (control images) specified") + control_mode = torch.tensor( + [-1 if elem is None else elem for elem in control_mode] + ) + control_mode = control_mode.view(-1,1).expand(len(control_image), 1) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 From 7c95f0b1de894ce18b2d981676893433eb7ff619 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 23 Sep 2024 21:34:44 +0200 Subject: [PATCH 3/4] fix --- src/diffusers/models/controlnet_flux.py | 23 ++++++++------- .../flux/pipeline_flux_controlnet.py | 29 ++++++++++--------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py index 036e5654a98e..88ad49d2b776 100644 --- a/src/diffusers/models/controlnet_flux.py +++ b/src/diffusers/models/controlnet_flux.py @@ -502,16 +502,17 @@ def forward( control_block_samples = block_samples control_single_block_samples = single_block_samples else: - control_block_samples = [ - control_block_sample + block_sample - for control_block_sample, block_sample in zip(control_block_samples, block_samples) - ] - - control_single_block_samples = [ - control_single_block_sample + block_sample - for control_single_block_sample, block_sample in zip( - control_single_block_samples, single_block_samples - ) - ] + if block_samples is not None and control_block_samples is not None: + control_block_samples = [ + control_block_sample + block_sample + for control_block_sample, block_sample in zip(control_block_samples, block_samples) + ] + if single_block_samples is not None and control_single_block_samples is not None: + control_single_block_samples = [ + control_single_block_sample + block_sample + for control_single_block_sample, block_sample in zip( + control_single_block_samples, single_block_samples + ) + ] return control_block_samples, control_single_block_samples diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 285402238774..8d08f85d29f0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -748,14 +748,11 @@ def __call__( ) # Here we ensure that `control_mode` has the same length as the control_image. - if not isinstance(control_mode, list): - control_mode = [control_mode] - if len(control_mode) > 1: - raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or a list containing 1 `int`") - control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.view(-1,1).expand(control_image.shape[0], 1) - - #print(f"control image shape {control_image.shape}, mode shape {control_mode.shape}") + if control_mode is not None: + if not isinstance(control_mode, int): + raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`") + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.view(-1,1).expand(control_image.shape[0], 1) elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] @@ -791,13 +788,19 @@ def __call__( control_image = control_images # Here we ensure that `control_mode` has the same length as the control_image. - if not isinstance(control_mode, list) or len(control_mode) != len(control_image): + if isinstance(control_mode, list) and len(control_mode) != len(control_image): raise ValueError("For Multi-ControlNet, `control_mode` must be a list of the same " + " length as the number of controlnets (control images) specified") - control_mode = torch.tensor( - [-1 if elem is None else elem for elem in control_mode] - ) - control_mode = control_mode.view(-1,1).expand(len(control_image), 1) + if not isinstance(control_mode, list): + control_mode = [control_mode] * len(control_image) + # set control mode + control_modes = [] + for cmode in control_mode: + if cmode is None: + cmode = -1 + control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long) + control_modes.append(control_mode) + control_mode = control_modes # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 From 560449d8f3c529f2777c88da6305d25ad15687ba Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 23 Sep 2024 22:23:30 +0200 Subject: [PATCH 4/4] fix use_guidance when controlnet is a multi and does not have config --- .../flux/pipeline_flux_controlnet.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index bc7e7f0099a1..6c072c482020 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -752,7 +752,7 @@ def __call__( if not isinstance(control_mode, int): raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`") control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) - control_mode = control_mode.view(-1,1).expand(control_image.shape[0], 1) + control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] @@ -789,11 +789,13 @@ def __call__( # Here we ensure that `control_mode` has the same length as the control_image. if isinstance(control_mode, list) and len(control_mode) != len(control_image): - raise ValueError("For Multi-ControlNet, `control_mode` must be a list of the same " + - " length as the number of controlnets (control images) specified") + raise ValueError( + "For Multi-ControlNet, `control_mode` must be a list of the same " + + " length as the number of controlnets (control images) specified" + ) if not isinstance(control_mode, list): control_mode = [control_mode] * len(control_image) - # set control mode + # set control mode control_modes = [] for cmode in control_mode: if cmode is None: @@ -846,9 +848,12 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - guidance = ( - torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None - ) + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + + guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None guidance = guidance.expand(latents.shape[0]) if guidance is not None else None # controlnet