Skip to content

Commit 41f9612

Browse files
author
Darshil Jariwala
committed
added copied from and removed unnecessary tests
1 parent 98ec6d8 commit 41f9612

File tree

2 files changed

+14
-61
lines changed

2 files changed

+14
-61
lines changed

src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py

Lines changed: 13 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -667,7 +667,7 @@ def check_inputs(
667667
raise ValueError(
668668
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
669669
)
670-
670+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
671671
def prepare_latents(
672672
self,
673673
batch_size,
@@ -731,6 +731,7 @@ def prepare_latents(
731731

732732
return outputs
733733

734+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
734735
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
735736
if isinstance(generator, list):
736737
image_latents = [
@@ -745,6 +746,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
745746

746747
return image_latents
747748

749+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
748750
def prepare_mask_latents(
749751
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
750752
):
@@ -786,6 +788,9 @@ def prepare_mask_latents(
786788
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
787789
)
788790

791+
# star
792+
793+
789794
# aligning device to prevent device errors when concating it with the latent model input
790795
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
791796
return mask, masked_image_latents
@@ -996,23 +1001,8 @@ def __call__(
9961001
height = height or self.unet.config.sample_size * self.vae_scale_factor
9971002
width = width or self.unet.config.sample_size * self.vae_scale_factor
9981003
# to deal with lora scaling and other possible forward hooks
999-
1004+
10001005
# 1. Check inputs. Raise error if not correct
1001-
# prompt,
1002-
# image,
1003-
# mask_image,
1004-
# height,
1005-
# width,
1006-
# strength,
1007-
# callback_steps,
1008-
# output_type,
1009-
# negative_prompt=None,
1010-
# prompt_embeds=None,
1011-
# negative_prompt_embeds=None,
1012-
# ip_adapter_image=None,
1013-
# ip_adapter_image_embeds=None,
1014-
# callback_on_step_end_tensor_inputs=None,
1015-
# padding_mask_crop=None,
10161006
self.check_inputs(
10171007
prompt,
10181008
image,
@@ -1066,7 +1056,7 @@ def __call__(
10661056
clip_skip=self.clip_skip,
10671057
)
10681058

1069-
# 4. set timesteps
1059+
# 4. set timesteps
10701060
timesteps, num_inference_steps = retrieve_timesteps(
10711061
self.scheduler, num_inference_steps, device, timesteps, sigmas
10721062
)
@@ -1098,7 +1088,7 @@ def __call__(
10981088
)
10991089
init_image = init_image.to(dtype=torch.float32)
11001090

1101-
# 6. Prepare latent variables
1091+
# 6. Prepare latent variables
11021092
num_channels_latents = self.vae.config.latent_channels
11031093
num_channels_unet = self.unet.config.in_channels
11041094
return_image_latents = num_channels_unet == 4
@@ -1171,7 +1161,7 @@ def __call__(
11711161
raise ValueError(
11721162
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
11731163
)
1174-
# 8.1 Prepare extra step kwargs.
1164+
# 9. Prepare extra step kwargs.
11751165
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
11761166

11771167
# For classifier free guidance, we need to do two forward passes.
@@ -1210,22 +1200,22 @@ def __call__(
12101200

12111201

12121202

1213-
# 6.1 Add image embeds for IP-Adapter
1203+
# 9.1 Add image embeds for IP-Adapter
12141204
added_cond_kwargs = (
12151205
{"image_embeds": ip_adapter_image_embeds}
12161206
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
12171207
else None
12181208
)
12191209

1220-
# 6.2 Optionally get Guidance Scale Embedding
1210+
# 9.2 Optionally get Guidance Scale Embedding
12211211
timestep_cond = None
12221212
if self.unet.config.time_cond_proj_dim is not None:
12231213
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
12241214
timestep_cond = self.get_guidance_scale_embedding(
12251215
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
12261216
).to(device=device, dtype=latents.dtype)
12271217

1228-
# 7. Denoising loop
1218+
# 10. Denoising loop
12291219
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
12301220

12311221

tests/pipelines/pag/test_pag_sd_inpaint.py

Lines changed: 1 addition & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
7878
time_cond_proj_dim=time_cond_proj_dim,
7979
layers_per_block=2,
8080
sample_size=32,
81-
in_channels=9,
81+
in_channels=4,
8282
out_channels=4,
8383
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
8484
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
@@ -148,43 +148,6 @@ def get_dummy_inputs(self, device, seed=0):
148148
}
149149
return inputs
150150

151-
152-
def test_pag_disable_enable(self):
153-
device = "cpu" # ensure determinism for the device-dependent torch.Generator
154-
components = self.get_dummy_components()
155-
156-
# base pipeline (expect same output when pag is disabled)
157-
pipe_sd = StableDiffusionInpaintPipeline(**components)
158-
pipe_sd = pipe_sd.to(device)
159-
pipe_sd.set_progress_bar_config(disable=None)
160-
161-
inputs = self.get_dummy_inputs(device)
162-
del inputs["pag_scale"]
163-
assert (
164-
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
165-
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
166-
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
167-
168-
# pag disabled with pag_scale=0.0
169-
pipe_pag = self.pipeline_class(**components)
170-
pipe_pag = pipe_pag.to(device)
171-
pipe_pag.set_progress_bar_config(disable=None)
172-
173-
inputs = self.get_dummy_inputs(device)
174-
inputs["pag_scale"] = 0.0
175-
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
176-
177-
# pag enabled
178-
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
179-
pipe_pag = pipe_pag.to(device)
180-
pipe_pag.set_progress_bar_config(disable=None)
181-
182-
inputs = self.get_dummy_inputs(device)
183-
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
184-
185-
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
186-
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
187-
188151
def test_pag_applied_layers(self):
189152
device = "cpu" # ensure determinism for the device-dependent torch.Generator
190153
components = self.get_dummy_components()

0 commit comments

Comments
 (0)