
Commit 66c2ab4

style changes as suggested
1 parent de3afaf commit 66c2ab4

2 files changed: +19 -21 lines changed


src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py

Lines changed: 5 additions & 11 deletions
@@ -74,6 +74,7 @@
         ```
 """
 
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -87,6 +88,7 @@ def retrieve_latents(
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -149,7 +151,8 @@ def retrieve_timesteps(
 
 class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin):
     r"""
-    [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for image-to-image generation using Stable Diffusion 3.
+    [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for image-to-image generation
+    using Stable Diffusion 3.
 
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -643,16 +646,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
         return timesteps, num_inference_steps - t_start
 
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.prepare_latents
-    def prepare_latents(
-        self,
-        image,
-        timestep,
-        batch_size,
-        num_images_per_prompt,
-        dtype,
-        device,
-        generator=None
-    ):
+    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"

tests/pipelines/pag/test_pag_sd3_img2img.py

Lines changed: 14 additions & 10 deletions
@@ -37,6 +37,7 @@
 
 enable_full_determinism()
 
+
 class StableDiffusion3PAGImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     pipeline_class = StableDiffusion3PAGImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) - {"height", "width"}
@@ -176,7 +177,6 @@ def test_pag_inference(self):
         inputs = self.get_dummy_inputs(device)
         image = pipe_pag(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
-        print(f"{image_slice=}")
 
         assert image.shape == (
             1,
@@ -185,12 +185,10 @@ def test_pag_inference(self):
             3,
         ), f"the shape of the output image should be (1, 32, 32, 3) but got {image.shape}"
 
-        expected_slice = np.array([
-            [0.7251651, 0.52043426, 0.5527822],
-            [0.7089102, 0.62233330, 0.5923926],
-            [0.4929751, 0.52322210, 0.5529656]
-        ])
-        max_diff = np.abs(image_slice.flatten() - expected_slice.flatten()).max()
+        expected_slice = np.array(
+            [0.7251651, 0.52043426, 0.5527822, 0.7089102, 0.62233330, 0.5923926, 0.4929751, 0.52322210, 0.5529656]
+        )
+        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
 
 
@@ -210,7 +208,9 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0, guidance_scale=7.0, pag_scale=0.7):
+    def get_inputs(
+        self, device, generator_device="cpu", dtype=torch.float32, seed=0, guidance_scale=7.0, pag_scale=0.7
+    ):
         img_url = (
             "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
         )
@@ -230,7 +230,9 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
         return inputs
 
     def test_pag_cfg(self):
-        pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.17"])
+        pipeline = AutoPipelineForImage2Image.from_pretrained(
+            self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.17"]
+        )
         pipeline.enable_model_cpu_offload()
         pipeline.set_progress_bar_config(disable=None)
 
@@ -246,7 +248,9 @@ def test_pag_cfg(self):
         ), f"output is different from expected, {image_slice.flatten()}"
 
     def test_pag_uncond(self):
-        pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"])
+        pipeline = AutoPipelineForImage2Image.from_pretrained(
+            self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"]
+        )
         pipeline.enable_model_cpu_offload()
         pipeline.set_progress_bar_config(disable=None)
 
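
A small sketch of the comparison the updated test performs; the numbers are hypothetical stand-ins for the slice values above, shown only to illustrate why the now one-dimensional expected_slice no longer needs its own .flatten() call:

import numpy as np

# Hypothetical stand-in for image[0, -3:, -3:, -1], which is a 3x3 array.
image_slice = np.array([[0.72, 0.52, 0.55], [0.71, 0.62, 0.59], [0.49, 0.52, 0.55]])

# Already flat (shape (9,)), so only image_slice needs flattening before subtraction.
expected_slice = np.array([0.72, 0.52, 0.55, 0.71, 0.62, 0.59, 0.49, 0.52, 0.55])

max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff <= 1e-3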