@@ -1582,6 +1582,293 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
return pipeline, state


class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
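    """
    Denoising step for Stable Diffusion XL with a ControlNet Union model: runs the full denoising loop,
    using separate CFG guiders for the UNet and the ControlNet, and optionally blends in inpainting
    latents when a mask and image latents are provided.
    """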
expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"]
model_name = "stable-diffusion-xl"

@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
("control_image", None),
("control_guidance_start", 0.0),
("control_guidance_end", 1.0),
("controlnet_conditioning_scale", 1.0),
("control_mode", 0),
("guess_mode", False),
("num_images_per_prompt", 1),
("guidance_scale", 5.0),
("guidance_rescale", 0.0),
("cross_attention_kwargs", None),
("generator", None),
("eta", 0.0),
("guider_kwargs", None),
]

@property
def intermediates_inputs(self) -> List[str]:
return [
"latents",
"batch_size",
"timesteps",
"num_inference_steps",
"prompt_embeds",
"negative_prompt_embeds",
"add_time_ids",
"negative_add_time_ids",
"pooled_prompt_embeds",
"negative_pooled_prompt_embeds",
"timestep_cond",
"mask",
"noise",
"image_latents",
"crops_coords",
]

@property
def intermediates_outputs(self) -> List[str]:
return ["latents"]

def __init__(self):
super().__init__()
self.components["guider"] = CFGGuider()
self.components["controlnet_guider"] = CFGGuider()
self.components["scheduler"] = None
self.components["unet"] = None
self.components["controlnet"] = None
control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False)
self.auxiliaries["control_image_processor"] = control_image_processor

@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
guidance_scale = state.get_input("guidance_scale")
guidance_rescale = state.get_input("guidance_rescale")
cross_attention_kwargs = state.get_input("cross_attention_kwargs")
guider_kwargs = state.get_input("guider_kwargs")
generator = state.get_input("generator")
eta = state.get_input("eta")
num_images_per_prompt = state.get_input("num_images_per_prompt")
# controlnet-specific inputs
control_image = state.get_input("control_image")
control_guidance_start = state.get_input("control_guidance_start")
control_guidance_end = state.get_input("control_guidance_end")
controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale")
control_mode = state.get_input("control_mode")
guess_mode = state.get_input("guess_mode")

batch_size = state.get_intermediate("batch_size")
latents = state.get_intermediate("latents")
timesteps = state.get_intermediate("timesteps")
num_inference_steps = state.get_intermediate("num_inference_steps")

prompt_embeds = state.get_intermediate("prompt_embeds")
negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds")
pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds")
negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds")
add_time_ids = state.get_intermediate("add_time_ids")
negative_add_time_ids = state.get_intermediate("negative_add_time_ids")

timestep_cond = state.get_intermediate("timestep_cond")

# inpainting
mask = state.get_intermediate("mask")
noise = state.get_intermediate("noise")
image_latents = state.get_intermediate("image_latents")
crops_coords = state.get_intermediate("crops_coords")

device = pipeline._execution_device

height, width = latents.shape[-2:]
height = height * pipeline.vae_scale_factor
width = width * pipeline.vae_scale_factor

# prepare controlnet inputs
controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet

# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]

global_pool_conditions = controlnet.config.global_pool_conditions
guess_mode = guess_mode or global_pool_conditions

num_control_type = controlnet.config.num_control_type

if not isinstance(control_image, list):
control_image = [control_image]

if not isinstance(control_mode, list):
control_mode = [control_mode]

        if len(control_image) != len(control_mode):
            raise ValueError("Expected len(control_image) == len(control_mode)")

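        # one-hot style vector marking which of the model's control types are active for this call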
control_type = [0 for _ in range(num_control_type)]
for control_idx in control_mode:
control_type[control_idx] = 1

control_type = torch.Tensor(control_type)

for idx, _ in enumerate(control_image):
control_image[idx] = pipeline.prepare_control_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
)
height, width = control_image[idx].shape[-2:]

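        # controlnet_keep[i] is 1.0 while step i falls inside the [control_guidance_start, control_guidance_end] window, else 0.0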
controlnet_keep = []
for i in range(len(timesteps)):
controlnet_keep.append(
1.0
- float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
)

# Prepare conditional inputs for unet using the guider
# adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
        disable_guidance = pipeline.unet.config.time_cond_proj_dim is not None
guider_kwargs = guider_kwargs or {}
guider_kwargs = {
**guider_kwargs,
"disable_guidance": disable_guidance,
"guidance_scale": guidance_scale,
"guidance_rescale": guidance_rescale,
"batch_size": batch_size,
}
pipeline.guider.set_guider(pipeline, guider_kwargs)
prompt_embeds = pipeline.guider.prepare_input(
prompt_embeds,
negative_prompt_embeds,
)
add_time_ids = pipeline.guider.prepare_input(
add_time_ids,
negative_add_time_ids,
)
pooled_prompt_embeds = pipeline.guider.prepare_input(
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)

added_cond_kwargs = {
"text_embeds": pooled_prompt_embeds,
"time_ids": add_time_ids,
}

# Prepare conditional inputs for controlnet using the guider
        controlnet_disable_guidance = disable_guidance or guess_mode
        controlnet_guider_kwargs = {
            **guider_kwargs,
            "disable_guidance": controlnet_disable_guidance,
            "guidance_scale": guidance_scale,
            "guidance_rescale": guidance_rescale,
            "batch_size": batch_size,
        }
pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs)
controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds)
controlnet_added_cond_kwargs = {
"text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds),
"time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids),
}
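        # expand each control image for the controlnet guider (duplicated when CFG is enabled for the controlnet)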
for idx, _ in enumerate(control_image):
control_image[idx] = pipeline.controlnet_guider.prepare_input(control_image[idx], control_image[idx])

# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)

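        # the repeat factor of 2 assumes a classifier-free-guidance batch (unconditional + conditional)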
control_type = (
control_type.reshape(1, -1)
.to(device, dtype=prompt_embeds.dtype)
.repeat(batch_size * num_images_per_prompt * 2, 1)
)
with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# prepare latents for unet using the guider
latent_model_input = pipeline.guider.prepare_input(latents, latents)

# prepare latents for controlnet using the guider
control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents)

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

down_block_res_samples, mid_block_res_sample = pipeline.controlnet(
pipeline.scheduler.scale_model_input(control_model_input, t),
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image,
control_type=control_type,
control_type_idx=control_mode,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
return_dict=False,
)

                # when guidance is applied for the unet but not for the controlnet,
                # pad the unconditional batch with zero residuals
down_block_res_samples = pipeline.guider.prepare_input(
down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples]
)
mid_block_res_sample = pipeline.guider.prepare_input(
mid_block_res_sample, torch.zeros_like(mid_block_res_sample)
)

latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

noise_pred = pipeline.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)

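                # inpainting: re-noise the initial latents to the next timestep and keep them wherever the mask is 0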
if mask is not None and image_latents is not None:
init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0]
init_latents_proper = image_latents
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents_proper = pipeline.scheduler.add_noise(
init_latents_proper, noise, torch.tensor([noise_timestep])
)

latents = (1 - init_mask) * init_latents_proper + init_mask * latents

if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
progress_bar.update()

pipeline.guider.reset_guider(pipeline)
pipeline.controlnet_guider.reset_guider(pipeline)
state.add_intermediate("latents", latents)

return pipeline, state

class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
expected_components = ["vae"]
model_name = "stable-diffusion-xl"