Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,53 @@ def __call__(
width = latent_width * self.vae_scale_factor

elif isinstance(self.controlnet, SD3MultiControlNetModel):
raise NotImplementedError("MultiControlNetModel is not supported for SD3ControlNetInpaintingPipeline.")
# Normalize inputs to lists matching number of control nets
num_cn = len(self.controlnet.nets)

if not isinstance(control_image, (list, tuple)):
control_images = [control_image] * num_cn
else:
control_images = list(control_image)

if not isinstance(control_mask, (list, tuple)):
control_masks = [control_mask] * num_cn
else:
control_masks = list(control_mask)

if len(control_images) != num_cn:
raise ValueError(
f"Expected {num_cn} control images for SD3MultiControlNetModel, got {len(control_images)}."
)
if len(control_masks) != num_cn:
raise ValueError(
f"Expected {num_cn} control masks for SD3MultiControlNetModel, got {len(control_masks)}."
)

# Prepare per-control inpainting conditions
prepared_controls = []
first_latent_size = None
for img_i, msk_i in zip(control_images, control_masks):
ctrl = self.prepare_image_with_mask(
image=img_i,
mask=msk_i,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=False,
)
if first_latent_size is None:
first_latent_size = ctrl.shape[-2:]
prepared_controls.append(ctrl)

latent_height, latent_width = first_latent_size
height = latent_height * self.vae_scale_factor
width = latent_width * self.vae_scale_factor

control_image = prepared_controls
else:
assert False

Expand All @@ -1128,6 +1174,10 @@ def __call__(
else:
controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds

# Ensure conditioning scale broadcast for multi-control
if isinstance(self.controlnet, SD3MultiControlNetModel) and not isinstance(controlnet_conditioning_scale, list):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
Expand Down
25 changes: 24 additions & 1 deletion tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
SD3Transformer2DModel,
StableDiffusion3ControlNetInpaintingPipeline,
)
from diffusers.models import SD3ControlNetModel
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
Expand Down Expand Up @@ -201,3 +201,26 @@ def test_controlnet_inpaint_sd3(self):
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
pass

def test_controlnet_inpaint_sd3_multi_control(self):
components = self.get_dummy_components()
# Duplicate the single controlnet into a MultiControlNet for smoke test
cn = components["controlnet"]
components["controlnet"] = SD3MultiControlNetModel([cn, cn])

sd_pipe = StableDiffusion3ControlNetInpaintingPipeline(**components)
sd_pipe = sd_pipe.to(torch_device, dtype=torch.float16)
sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device)
# Provide lists for multi-control
inputs["control_image"] = [inputs["control_image"], inputs["control_image"]]
inputs["control_mask"] = [inputs["control_mask"], inputs["control_mask"]]
inputs["controlnet_conditioning_scale"] = [1.0, 0.5]
inputs["num_inference_steps"] = 2

output = sd_pipe(**inputs)
image = output.images

# Shape check only (deterministic slice check not required for multi-control smoke test)
assert image.shape == (1, 32, 32, 3)