
Commit 1b89ac1

prepare_latents_img2img pipeline method -> function, maybe do the same for others?
1 parent: eb94150

File tree

1 file changed: +83 −85 lines

src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py (83 additions, 85 deletions)
@@ -127,6 +127,86 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


+def prepare_latents_img2img(vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True):
+
+    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+        raise ValueError(
+            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+        )
+
+    image = image.to(device=device, dtype=dtype)
+
+    batch_size = batch_size * num_images_per_prompt
+
+    if image.shape[1] == 4:
+        init_latents = image
+
+    else:
+        latents_mean = latents_std = None
+        if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+            latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+        if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+            latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        if vae.config.force_upcast:
+            image = image.float()
+            vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
+            init_latents = [
+                retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
+            ]
+            init_latents = torch.cat(init_latents, dim=0)
+        else:
+            init_latents = retrieve_latents(vae.encode(image), generator=generator)
+
+        if vae.config.force_upcast:
+            vae.to(dtype)
+
+        init_latents = init_latents.to(dtype)
+        if latents_mean is not None and latents_std is not None:
+            latents_mean = latents_mean.to(device=device, dtype=dtype)
+            latents_std = latents_std.to(device=device, dtype=dtype)
+            init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std
+        else:
+            init_latents = vae.config.scaling_factor * init_latents
+
+    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+        # expand init_latents for batch_size
+        additional_image_per_prompt = batch_size // init_latents.shape[0]
+        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+        raise ValueError(
+            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+        )
+    else:
+        init_latents = torch.cat([init_latents], dim=0)
+
+    if add_noise:
+        shape = init_latents.shape
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        # get latents
+        init_latents = scheduler.add_noise(init_latents, noise, timestep)
+
+    latents = init_latents
+
+    return latents
+
+
 class StableDiffusionXLInputStep(PipelineBlock):
     model_name = "stable-diffusion-xl"

@@ -751,89 +831,6 @@ def intermediates_inputs(self) -> List[InputParam]:
     def intermediates_outputs(self) -> List[OutputParam]:
         return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")]

-    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents with self -> components
-    # YiYi TODO: refactor using _encode_vae_image
-    @staticmethod
-    def prepare_latents_img2img(
-        components, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
-    ):
-        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
-            raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
-            )
-
-        image = image.to(device=device, dtype=dtype)
-
-        batch_size = batch_size * num_images_per_prompt
-
-        if image.shape[1] == 4:
-            init_latents = image
-
-        else:
-            latents_mean = latents_std = None
-            if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
-                latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
-            if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
-                latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
-            # make sure the VAE is in float32 mode, as it overflows in float16
-            if components.vae.config.force_upcast:
-                image = image.float()
-                components.vae.to(dtype=torch.float32)
-
-            if isinstance(generator, list) and len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
-            elif isinstance(generator, list):
-                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
-                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
-                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
-                    raise ValueError(
-                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
-                    )
-
-                init_latents = [
-                    retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
-                    for i in range(batch_size)
-                ]
-                init_latents = torch.cat(init_latents, dim=0)
-            else:
-                init_latents = retrieve_latents(components.vae.encode(image), generator=generator)
-
-            if components.vae.config.force_upcast:
-                components.vae.to(dtype)
-
-            init_latents = init_latents.to(dtype)
-            if latents_mean is not None and latents_std is not None:
-                latents_mean = latents_mean.to(device=device, dtype=dtype)
-                latents_std = latents_std.to(device=device, dtype=dtype)
-                init_latents = (init_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
-            else:
-                init_latents = components.vae.config.scaling_factor * init_latents
-
-        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
-            # expand init_latents for batch_size
-            additional_image_per_prompt = batch_size // init_latents.shape[0]
-            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
-        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
-            raise ValueError(
-                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-            )
-        else:
-            init_latents = torch.cat([init_latents], dim=0)
-
-        if add_noise:
-            shape = init_latents.shape
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            # get latents
-            init_latents = components.scheduler.add_noise(init_latents, noise, timestep)
-
-        latents = init_latents
-
-        return latents
-
     @torch.no_grad()
     def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
@@ -842,8 +839,9 @@ def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineSt
         block_state.device = components._execution_device
         block_state.add_noise = True if block_state.denoising_start is None else False
         if block_state.latents is None:
-            block_state.latents = self.prepare_latents_img2img(
-                components,
+            block_state.latents = prepare_latents_img2img(
+                components.vae,
+                components.scheduler,
                 block_state.image_latents,
                 block_state.latent_timestep,
                 block_state.batch_size,