Commit 7a34832

[modular] Stable Diffusion XL ControlNet Union (#10509)
StableDiffusionXLControlNetUnionDenoiseStep
1 parent e973de6 commit 7a34832

File tree

1 file changed: +287 −0 lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 287 additions & 0 deletions
@@ -1582,6 +1582,293 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:

```python
        return pipeline, state


class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock):
    expected_components = ["unet", "controlnet", "scheduler", "guider", "controlnet_guider"]
    model_name = "stable-diffusion-xl"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            ("control_image", None),
            ("control_guidance_start", 0.0),
            ("control_guidance_end", 1.0),
            ("controlnet_conditioning_scale", 1.0),
            ("control_mode", 0),
            ("guess_mode", False),
            ("num_images_per_prompt", 1),
            ("guidance_scale", 5.0),
            ("guidance_rescale", 0.0),
            ("cross_attention_kwargs", None),
            ("generator", None),
            ("eta", 0.0),
            ("guider_kwargs", None),
        ]

    @property
    def intermediates_inputs(self) -> List[str]:
        return [
            "latents",
            "batch_size",
            "timesteps",
            "num_inference_steps",
            "prompt_embeds",
            "negative_prompt_embeds",
            "add_time_ids",
            "negative_add_time_ids",
            "pooled_prompt_embeds",
            "negative_pooled_prompt_embeds",
            "timestep_cond",
            "mask",
            "noise",
            "image_latents",
            "crops_coords",
        ]

    @property
    def intermediates_outputs(self) -> List[str]:
        return ["latents"]

    def __init__(self):
        super().__init__()
        self.components["guider"] = CFGGuider()
        self.components["controlnet_guider"] = CFGGuider()
        self.components["scheduler"] = None
        self.components["unet"] = None
        self.components["controlnet"] = None
        control_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False)
        self.auxiliaries["control_image_processor"] = control_image_processor

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        guidance_scale = state.get_input("guidance_scale")
        guidance_rescale = state.get_input("guidance_rescale")
        cross_attention_kwargs = state.get_input("cross_attention_kwargs")
        guider_kwargs = state.get_input("guider_kwargs")
        generator = state.get_input("generator")
        eta = state.get_input("eta")
        num_images_per_prompt = state.get_input("num_images_per_prompt")
        # controlnet-specific inputs
        control_image = state.get_input("control_image")
        control_guidance_start = state.get_input("control_guidance_start")
        control_guidance_end = state.get_input("control_guidance_end")
        controlnet_conditioning_scale = state.get_input("controlnet_conditioning_scale")
        control_mode = state.get_input("control_mode")
        guess_mode = state.get_input("guess_mode")

        batch_size = state.get_intermediate("batch_size")
        latents = state.get_intermediate("latents")
        timesteps = state.get_intermediate("timesteps")
        num_inference_steps = state.get_intermediate("num_inference_steps")

        prompt_embeds = state.get_intermediate("prompt_embeds")
        negative_prompt_embeds = state.get_intermediate("negative_prompt_embeds")
        pooled_prompt_embeds = state.get_intermediate("pooled_prompt_embeds")
        negative_pooled_prompt_embeds = state.get_intermediate("negative_pooled_prompt_embeds")
        add_time_ids = state.get_intermediate("add_time_ids")
        negative_add_time_ids = state.get_intermediate("negative_add_time_ids")

        timestep_cond = state.get_intermediate("timestep_cond")

        # inpainting
        mask = state.get_intermediate("mask")
        noise = state.get_intermediate("noise")
        image_latents = state.get_intermediate("image_latents")
        crops_coords = state.get_intermediate("crops_coords")

        device = pipeline._execution_device

        height, width = latents.shape[-2:]
        height = height * pipeline.vae_scale_factor
        width = width * pipeline.vae_scale_factor

        # prepare controlnet inputs
        controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]

        global_pool_conditions = controlnet.config.global_pool_conditions
        guess_mode = guess_mode or global_pool_conditions

        num_control_type = controlnet.config.num_control_type

        if not isinstance(control_image, list):
            control_image = [control_image]

        if not isinstance(control_mode, list):
            control_mode = [control_mode]

        if len(control_image) != len(control_mode):
            raise ValueError("Expected len(control_image) == len(control_type)")

        control_type = [0 for _ in range(num_control_type)]
        for control_idx in control_mode:
            control_type[control_idx] = 1

        control_type = torch.Tensor(control_type)

        for idx, _ in enumerate(control_image):
            control_image[idx] = pipeline.prepare_control_image(
                image=control_image[idx],
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                crops_coords=crops_coords,
            )
            height, width = control_image[idx].shape[-2:]

        controlnet_keep = []
        for i in range(len(timesteps)):
            controlnet_keep.append(
                1.0
                - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
            )

        # Prepare conditional inputs for unet using the guider
        # adding default guider arguments: disable_guidance, guidance_scale, guidance_rescale
        disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
        guider_kwargs = guider_kwargs or {}
        guider_kwargs = {
            **guider_kwargs,
            "disable_guidance": disable_guidance,
            "guidance_scale": guidance_scale,
            "guidance_rescale": guidance_rescale,
            "batch_size": batch_size,
        }
        pipeline.guider.set_guider(pipeline, guider_kwargs)
        prompt_embeds = pipeline.guider.prepare_input(
            prompt_embeds,
            negative_prompt_embeds,
        )
        add_time_ids = pipeline.guider.prepare_input(
            add_time_ids,
            negative_add_time_ids,
        )
        pooled_prompt_embeds = pipeline.guider.prepare_input(
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        added_cond_kwargs = {
            "text_embeds": pooled_prompt_embeds,
            "time_ids": add_time_ids,
        }

        # Prepare conditional inputs for controlnet using the guider
        controlnet_disable_guidance = True if disable_guidance or guess_mode else False
        controlnet_guider_kwargs = guider_kwargs or {}
        controlnet_guider_kwargs = {
            **controlnet_guider_kwargs,
            "disable_guidance": controlnet_disable_guidance,
            "guidance_scale": guidance_scale,
            "guidance_rescale": guidance_rescale,
            "batch_size": batch_size,
        }
        pipeline.controlnet_guider.set_guider(pipeline, controlnet_guider_kwargs)
        controlnet_prompt_embeds = pipeline.controlnet_guider.prepare_input(prompt_embeds)
        controlnet_added_cond_kwargs = {
            "text_embeds": pipeline.controlnet_guider.prepare_input(pooled_prompt_embeds),
            "time_ids": pipeline.controlnet_guider.prepare_input(add_time_ids),
        }
        for idx, _ in enumerate(control_image):
            control_image[idx] = pipeline.controlnet_guider.prepare_input(control_image[idx], control_image[idx])

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * pipeline.scheduler.order, 0)

        control_type = (
            control_type.reshape(1, -1)
            .to(device, dtype=prompt_embeds.dtype)
            .repeat(batch_size * num_images_per_prompt * 2, 1)
        )
        with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # prepare latents for unet using the guider
                latent_model_input = pipeline.guider.prepare_input(latents, latents)

                # prepare latents for controlnet using the guider
                control_model_input = pipeline.controlnet_guider.prepare_input(latents, latents)

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = pipeline.controlnet(
                    pipeline.scheduler.scale_model_input(control_model_input, t),
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=control_image,
                    control_type=control_type,
                    control_type_idx=control_mode,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    added_cond_kwargs=controlnet_added_cond_kwargs,
                    return_dict=False,
                )

                # when we apply guidance for unet, but not for controlnet:
                # add 0 to the unconditional batch
                down_block_res_samples = pipeline.guider.prepare_input(
                    down_block_res_samples, [torch.zeros_like(d) for d in down_block_res_samples]
                )
                mid_block_res_sample = pipeline.guider.prepare_input(
                    mid_block_res_sample, torch.zeros_like(mid_block_res_sample)
                )

                latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]
                # perform guidance
                noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents)
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if mask is not None and image_latents is not None:
                    init_mask = pipeline.guider._maybe_split_prepared_input(mask)[0]
                    init_latents_proper = image_latents
                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = pipeline.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor([noise_timestep])
                        )

                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                    progress_bar.update()

        pipeline.guider.reset_guider(pipeline)
        pipeline.controlnet_guider.reset_guider(pipeline)
        state.add_intermediate("latents", latents)

        return pipeline, state


class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
    expected_components = ["vae"]
    model_name = "stable-diffusion-xl"
```
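Two pieces of the block are easy to illustrate in isolation: the one-hot `control_type` vector built from `control_mode`, and the `controlnet_keep` schedule that enables conditioning only inside the `[control_guidance_start, control_guidance_end]` window. A minimal, runnable sketch with made-up example values (`num_control_type`, `control_mode`, and the step count are hypothetical, not from the commit):

```python
import torch

# Example values; a real run reads num_control_type from controlnet.config.
num_control_type = 6             # hypothetical union checkpoint with 6 condition types
control_mode = [0, 3]            # request condition types 0 and 3
batch_size = 1
num_images_per_prompt = 1

# One-hot vector marking the active condition types (mirrors the block above)
control_type = [0 for _ in range(num_control_type)]
for control_idx in control_mode:
    control_type[control_idx] = 1
control_type = torch.Tensor(control_type)

# Repeated per sample, x2 for the cond/uncond halves of the CFG batch
control_type = control_type.reshape(1, -1).repeat(batch_size * num_images_per_prompt * 2, 1)
print(control_type.shape)   # torch.Size([2, 6])

# controlnet_keep: 1.0 only while the step index falls inside the
# [control_guidance_start, control_guidance_end] window
timesteps = list(range(10))  # stand-in for the scheduler's timesteps
control_guidance_start, control_guidance_end = 0.0, 0.5
controlnet_keep = [
    1.0 - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
    for i in range(len(timesteps))
]
print(controlnet_keep)       # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```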

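The repeated `guider.prepare_input(x, y)` calls implement classifier-free guidance batching. A toy stand-in, assuming the guider concatenates the unconditional input ahead of the conditional one when guidance is enabled (the real `CFGGuider` is diffusers-internal and more general); it also shows the zero-padding trick the loop uses for ControlNet residuals when the controlnet itself ran without guidance:

```python
import torch

def prepare_input(cond, uncond=None, guidance_enabled=True):
    """Toy stand-in for CFGGuider.prepare_input; the cat order and the
    signature are assumptions, not taken from the commit."""
    if not guidance_enabled:
        return cond
    uncond = cond if uncond is None else uncond
    return torch.cat([uncond, cond], dim=0)

latents = torch.randn(1, 4, 128, 128)
residual = torch.randn(1, 320, 128, 128)   # stand-in for a controlnet down-block residual

# The unet sees a doubled batch for classifier-free guidance ...
latent_model_input = prepare_input(latents, latents)
print(latent_model_input.shape)            # torch.Size([2, 4, 128, 128])

# ... and when guidance applies to the unet but not the controlnet, the
# residuals are padded with zeros on the unconditional half, as in the loop:
padded = prepare_input(residual, torch.zeros_like(residual))
print(padded.shape)                        # torch.Size([2, 320, 128, 128])
```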
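For orientation, a hypothetical driver for the new block. Only the `__call__(pipeline, state)` contract and the state accessors (`get_input`, `get_intermediate`, `add_intermediate`) appear in this commit; the wiring below, including how `pipeline` and `state` get populated, is assumed:

```python
# Hypothetical wiring sketch; the import path follows the file added here.
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
    StableDiffusionXLControlNetUnionDenoiseStep,
)

def run_union_denoise(pipeline, state):
    """`pipeline` is assumed to expose the block's expected_components
    ("unet", "controlnet", "scheduler", "guider", "controlnet_guider") and the
    helpers the block calls (prepare_control_image, prepare_extra_step_kwargs,
    progress_bar, _execution_device); `state` is a PipelineState that upstream
    blocks (text encoding, timestep and latent preparation, image encoding)
    have already filled with the block's inputs and intermediates_inputs."""
    denoise_step = StableDiffusionXLControlNetUnionDenoiseStep()
    pipeline, state = denoise_step(pipeline, state)
    # the refreshed latents are what StableDiffusionXLDecodeLatentsStep consumes
    return state.get_intermediate("latents")
```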