Skip to content

Commit 2e0f5c8

Browse files
committed
start to add inpaint
1 parent 1d63306 commit 2e0f5c8

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,82 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
544544
return pipeline, state
545545

546546

547+
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
    """Pipeline block that produces the initial ``latents`` intermediate for SDXL inpainting.

    The block preprocesses the input image and, unless the caller already supplied
    ``latents``, delegates to ``pipeline.prepare_latents_img2img`` to build them.

    NOTE(review): no mask input is read here and the img2img helper is reused;
    presumably the inpaint-specific handling is still to be wired in -- confirm.
    """

    expected_components = ["vae", "scheduler"]
    model_name = "stable-diffusion-xl"

    def __init__(self):
        super().__init__()
        # Registration order is kept: image processor first, then the two components.
        self.auxiliaries["image_processor"] = VaeImageProcessor()
        self.components["vae"] = None
        self.components["scheduler"] = None

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        """User-facing inputs as ``(name, default)`` pairs."""
        return [
            ("height", None),
            ("width", None),
            ("generator", None),
            ("latents", None),
            ("num_images_per_prompt", 1),
            ("device", None),
            ("dtype", None),
            ("image", None),
            ("denoising_start", None),
        ]

    @property
    def intermediates_inputs(self) -> List[str]:
        """Names of the intermediates this block reads from the state."""
        return ["batch_size", "latent_timestep", "prompt_embeds"]

    @property
    def intermediates_outputs(self) -> List[str]:
        """Names of the intermediates this block writes to the state."""
        return ["latents"]

    @torch.no_grad()
    def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState:
        """Compute the starting latents and store them as the ``latents`` intermediate.

        Returns the ``(pipeline, state)`` pair after ``state`` has been updated.
        """
        # Values supplied directly by the caller.
        initial_latents = state.get_input("latents")
        images_per_prompt = state.get_input("num_images_per_prompt")
        rng = state.get_input("generator")
        target_device = state.get_input("device")
        target_dtype = state.get_input("dtype")
        source_image = state.get_input("image")
        denoising_start = state.get_input("denoising_start")

        # Values written to the state by other blocks.
        batch_size = state.get_intermediate("batch_size")
        prompt_embeds = state.get_intermediate("prompt_embeds")
        latent_timestep = state.get_intermediate("latent_timestep")

        # dtype fallback order: explicit input, then prompt embeddings, then the VAE.
        if target_dtype is None:
            if prompt_embeds is not None:
                target_dtype = prompt_embeds.dtype
            else:
                target_dtype = pipeline.vae.dtype

        if target_device is None:
            target_device = pipeline._execution_device

        # The image is preprocessed even when latents were passed in by the caller.
        source_image = pipeline.image_processor.preprocess(source_image)

        # Noise is added only when no explicit denoising start was requested.
        add_noise = denoising_start is None

        if initial_latents is None:
            initial_latents = pipeline.prepare_latents_img2img(
                source_image,
                latent_timestep,
                batch_size,
                images_per_prompt,
                target_dtype,
                target_device,
                rng,
                add_noise,
            )

        state.add_intermediate("latents", initial_latents)

        return pipeline, state
621+
622+
547623
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
548624
expected_components = ["vae", "scheduler"]
549625
model_name = "stable-diffusion-xl"
@@ -2026,6 +2102,100 @@ def prepare_latents_img2img(
20262102

20272103
return latents
20282104

2105+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents
2106+
def prepare_latents_inpaint(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
    add_noise=True,
    return_noise=False,
    return_image_latents=False,
):
    """Build the starting latents for an inpainting run.

    Args:
        batch_size: Effective batch size; also the required length of a generator list.
        num_channels_latents: Channel dimension of the latent shape.
        height, width: Divided by ``self.vae_scale_factor`` to obtain the latent spatial size.
        dtype, device: Target dtype/device for generated noise and image latents.
        generator: A generator or a list of generators (one per batch element).
        latents: Optional caller-provided tensor; when given with ``add_noise=True`` it is
            used as the noise and scaled by ``self.scheduler.init_noise_sigma``.
        image: Input image. A tensor whose second dimension is 4 is used directly as
            image latents; otherwise it is encoded with ``self._encode_vae_image`` when
            image latents are needed.
        timestep: Timestep passed to ``self.scheduler.add_noise`` when blending noise
            into the image latents (``is_strength_max=False``).
        is_strength_max: When true, start from noise only; otherwise from image + noise.
        add_noise: When false, the returned latents are the image latents themselves.
        return_noise: Append the noise tensor to the returned tuple.
        return_image_latents: Append the image latents to the returned tuple.

    Returns:
        A tuple ``(latents,)`` extended, in this order, by ``noise`` and ``image_latents``
        when the corresponding ``return_*`` flag is set.

    Raises:
        ValueError: If the generator list length differs from ``batch_size``; if
            ``is_strength_max`` is false and ``image`` or ``timestep`` is missing; or if
            image latents are required but ``image`` is ``None``.
    """
    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
            "However, either the image or the noise timestep has not been provided."
        )

    # Image latents are consumed when the caller asks for them, when the un-noised image is
    # the starting point (add_noise=False), or when noise is blended into the image.
    needs_image_latents = return_image_latents or not add_noise or (latents is None and not is_strength_max)

    image_latents = None
    if image is None:
        # `image` defaults to None; only fail when the image is actually required,
        # instead of crashing on `image.shape` for every call.
        if needs_image_latents:
            raise ValueError(
                "`image` must be provided when `add_noise` is False or `return_image_latents` is True."
            )
    elif image.shape[1] == 4:
        # A 4-channel tensor is treated as already-encoded latents.
        image_latents = image.to(device=device, dtype=dtype)
        image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
    elif needs_image_latents:
        image = image.to(device=device, dtype=dtype)
        image_latents = self._encode_vae_image(image=image, generator=generator)
        image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)

    if latents is None and add_noise:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        # if strength is 1. then initialise the latents to noise, else initial to image + noise
        latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
        # if pure noise then scale the initial latents by the Scheduler's init sigma
        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
    elif add_noise:
        # Caller-provided latents act as the noise.
        noise = latents.to(device)
        latents = noise * self.scheduler.init_noise_sigma
    else:
        # Noise is still drawn so it can be returned and so generator consumption is unchanged.
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = image_latents.to(device)

    outputs = (latents,)

    if return_noise:
        outputs += (noise,)

    if return_image_latents:
        outputs += (image_latents,)

    return outputs
2171+
2172+
2173+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image
2174+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """Encode ``image`` with ``self.vae`` and return latents scaled by the VAE scaling factor.

    Args:
        image: Image tensor to encode; its dtype is the dtype of the returned latents.
        generator: A generator, or a list of generators indexed per element of the
            first dimension of ``image``, forwarded to ``retrieve_latents``.

    Returns:
        The latents cast to the input dtype and multiplied by
        ``self.vae.config.scaling_factor``.
    """
    dtype = image.dtype
    force_upcast = self.vae.config.force_upcast
    if force_upcast:
        # Encode in float32 when the VAE config requests upcasting.
        image = image.float()
        self.vae.to(dtype=torch.float32)

    try:
        if isinstance(generator, list):
            # One generator per sample: encode each sample separately, then re-batch.
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
    finally:
        # Cast the VAE back even if encoding raised, so it is not left in float32.
        if force_upcast:
            self.vae.to(dtype)

    image_latents = image_latents.to(dtype)
    image_latents = self.vae.config.scaling_factor * image_latents

    return image_latents
2196+
2197+
2198+
20292199
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
20302200
def prepare_extra_step_kwargs(self, generator, eta):
20312201
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature

0 commit comments

Comments
 (0)