Commit 0de3c00

Create test_pag_controlnet_sd_img2img.py

1 parent 81683e2 commit 0de3c00

1 file changed: 209 additions, 0 deletions

@@ -0,0 +1,209 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/

import gc
import inspect
import random
import tempfile
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDIMScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPAGImg2ImgPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    load_numpy,
    require_torch_gpu,
    slow,
    torch_device,
)
from diffusers.utils.torch_utils import randn_tensor

from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import (
    IPAdapterTesterMixin,
    PipelineKarrasSchedulerTesterMixin,
    PipelineLatentTesterMixin,
    PipelineTesterMixin,
)


enable_full_determinism()

class StableDiffusionControlNetPAGImg2ImgPipelineFastTests(
    IPAdapterTesterMixin,
    PipelineLatentTesterMixin,
    PipelineKarrasSchedulerTesterMixin,
    PipelineTesterMixin,
    unittest.TestCase,
):
    pipeline_class = StableDiffusionControlNetPAGImg2ImgPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"control_image"})
    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
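    # Note (added for clarity, not in the original commit): `params` and
    # `batch_params` drive the shared checks in PipelineTesterMixin; extending
    # `params` with `pag_scale` and `pag_adaptive_scale` makes those common
    # tests exercise the PAG-specific call arguments as well.
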
    def get_dummy_components(self):
        # Reseed before each module so every component initializes reproducibly.
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            norm_num_groups=1,
        )
        torch.manual_seed(0)
        controlnet = ControlNetModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            in_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            cross_attention_dim=32,
            conditioning_embedding_out_channels=(16, 32),
            norm_num_groups=1,
        )
        torch.manual_seed(0)
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "controlnet": controlnet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
            "image_encoder": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        # The control image is 64x64: the 32-pixel latent size times the
        # embedder scale factor of 2, matching the resized init image below.
        controlnet_embedder_scale_factor = 2
        control_image = randn_tensor(
            (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
            generator=generator,
            device=torch.device(device),
        )
        image = floats_tensor(control_image.shape, rng=random.Random(seed)).to(device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "pag_scale": 3.0,
            "output_type": "np",
            "image": image,
            "control_image": control_image,
        }

        return inputs

    def test_pag_disable_enable(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()

        # base pipeline (expect same output when pag is disabled)
        pipe_sd = StableDiffusionControlNetImg2ImgPipeline(**components)
        pipe_sd = pipe_sd.to(device)
        pipe_sd.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        del inputs["pag_scale"]
        assert (
            "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
        ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
        out = pipe_sd(**inputs).images[0, -3:, -3:, -1]

        # pag disabled with pag_scale=0.0
        pipe_pag = self.pipeline_class(**components)
        pipe_pag = pipe_pag.to(device)
        pipe_pag.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["pag_scale"] = 0.0
        out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]

        # pag enabled
        pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
        pipe_pag = pipe_pag.to(device)
        pipe_pag.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]

        assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
        assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
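
For context, a minimal usage sketch of the pipeline this test exercises. The checkpoint IDs and image paths below are illustrative assumptions, not part of the commit:

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPAGImg2ImgPipeline
    from diffusers.utils import load_image

    # Hypothetical checkpoints; any SD 1.5-compatible pair should work.
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPAGImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
    ).to("cuda")

    init_image = load_image("init.png")      # placeholder: any RGB init image
    control_image = load_image("canny.png")  # placeholder: matching edge map

    image = pipe(
        "A painting of a squirrel eating a burger",
        image=init_image,
        control_image=control_image,
        guidance_scale=6.0,
        pag_scale=3.0,  # > 0 turns perturbed-attention guidance on
        num_inference_steps=30,
    ).images[0]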