31 changes: 17 additions & 14 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -747,10 +747,15 @@ def __call__(
width_control_image,
)

# set control mode
if control_mode is not None:
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
# Here we ensure that `control_mode` has the same length as the control_image.
if not isinstance(control_mode, list):
control_mode = [control_mode]
if len(control_mode) > 1:
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or a list containing 1 `int`")
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.view(-1,1).expand(control_image.shape[0], 1)

#print(f"control image shape {control_image.shape}, mode shape {control_mode.shape}")

elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
@@ -785,16 +790,14 @@ def __call__(

control_image = control_images

# set control mode
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
# Here we ensure that `control_mode` has the same length as the control_image.
if not isinstance(control_mode, list) or len(control_mode) != len(control_image):
raise ValueError("For Multi-ControlNet, `control_mode` must be a list of the same " +
" length as the number of controlnets (control images) specified")
control_mode = torch.tensor(
[-1 if elem is None else elem for elem in control_mode]
)
control_mode = control_mode.view(-1,1).expand(len(control_image), 1)
Collaborator:
Is this right though? For multi-controlnet, we loop through each controlnet, control_image, and control_mode, and a control_image element and its control_mode entry do not have the same batch size:

for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
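To make the shape mismatch concrete, here is a minimal sketch with dummy tensors (illustrative only; the variable names mirror the quoted loop, and the shapes are assumptions for the example):

import torch

# One control image per controlnet, each with a prompt batch of 2.
controlnet_cond = [torch.randn(2, 3, 512, 512), torch.randn(2, 3, 512, 512)]
# After the normalization in this PR: one mode row per controlnet, shape (2, 1).
controlnet_mode = torch.tensor([0, -1]).view(-1, 1)
conditioning_scale = [1.0, 1.0]

for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
    # image has batch size 2, while mode is a single-element row for this controlnet.
    print(i, image.shape, mode.shape, scale)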

Contributor Author:
For the following setup:

import numpy as np
import torch
from diffusers import AutoencoderKL, FluxControlNetModel, FluxControlNetPipeline, FluxMultiControlNetModel
from diffusers.utils import load_image

base_model = "black-forest-labs/FLUX.1-dev"  # assumed; the base checkpoint is not named in the original comment

controlnet_union = FluxControlNetModel.from_pretrained(
    'InstantX/FLUX.1-dev-Controlnet-Union', torch_dtype=torch.bfloat16
)
controlnet_depth = FluxControlNetModel.from_pretrained(
    "Shakker-Labs/FLUX.1-dev-ControlNet-Depth", torch_dtype=torch.bfloat16
)

multinet = FluxMultiControlNetModel([controlnet_union, controlnet_depth])
vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multinet, torch_dtype=torch.bfloat16
)

def pil_to_numpy(image):
    """to (c,h,w)"""
    return (np.array(image).astype(np.float32)/255.).swapaxes(1,2).swapaxes(0,1)

def pil_to_torch(image):
    return torch.from_numpy(pil_to_numpy(image)).float()

control_image = load_image(
    "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union-alpha/resolve/main/images/canny.jpg"
).resize((512,512))

If you're asking whether passing an actual tensor for control_image (instead of a PIL image) would work, i.e.:

# this will have shape (3,512,512)
control_image = pil_to_torch(control_image)
# this will have shape (4,3,512,512), i.e. batch size of 4
control_image = control_image.unsqueeze(0).repeat(4,1,1,1)

pipe(
    # the controlnets here are [union, union]
    control_image=[control_image, control_image],  # each inner control_image is batched
    control_mode=[0, None], 
    controlnet_conditioning_scale=[1., 1.]
)

We will get an error at _pack_latents; again, this is because the pipeline completely ignores the batch size of control_image. We initially mentioned fixing this in a future PR, but it also depends on which other pipelines have this issue; I looked at the SDXL controlnet pipeline and it does not seem to be affected. It wouldn't be unreasonable to fix it in this PR, but we can do a new PR if you prefer.
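As a rough illustration of why this surfaces as a reshape error (plain torch, not the pipeline's actual packing code; the latent shape and the prompt-derived batch size of 1 are assumptions for the example):

import torch

control_latents = torch.randn(4, 16, 64, 64)  # batch of 4 coming from the repeated control_image
batch_size = 1                                # what the pipeline derives from a single prompt

try:
    # A view analogous to the packing step: the element count no longer matches
    # once the assumed batch size disagrees with the tensor's actual batch.
    control_latents.view(batch_size, 16, 32, 2, 32, 2)
except RuntimeError as err:
    print(err)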

Also, to clarify, this does work:

# This is a PIL image
control_image = obj["control_image"]

pipe(
    control_image=[control_image, control_image], 
    control_mode=[0, None], 
    controlnet_conditioning_scale=[1., 1.]
)
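To recap the calling convention the two hunks above enforce, here is a standalone sketch of the control_mode normalization outside the pipeline (plain torch; the image shapes are illustrative):

import torch

# Single FluxControlNet: an int (or one-element list) is broadcast over the control image batch.
control_image = torch.randn(4, 3, 512, 512)
control_mode = 0
if not isinstance(control_mode, list):
    control_mode = [control_mode]
mode = torch.tensor(control_mode, dtype=torch.long).view(-1, 1).expand(control_image.shape[0], 1)
print(mode.shape)  # torch.Size([4, 1])

# FluxMultiControlNetModel: one entry per control image, with None mapped to -1.
control_images = [torch.randn(1, 3, 512, 512), torch.randn(1, 3, 512, 512)]
control_mode = [0, None]
mode = torch.tensor([-1 if m is None else m for m in control_mode]).view(-1, 1).expand(len(control_images), 1)
print(mode.shape)  # torch.Size([2, 1])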


# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4