Skip to content

Commit 5c10e68

Browse files
authored
Add SD2 inpainting integration tests (#1412)
SD2 inpainting integration tests
1 parent d50e321 commit 5c10e68

File tree

1 file changed

+345
-0
lines changed

1 file changed

+345
-0
lines changed
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import gc
17+
import random
18+
import unittest
19+
20+
import numpy as np
21+
import torch
22+
23+
from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
24+
from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device
25+
from diffusers.utils.testing_utils import require_torch_gpu
26+
from PIL import Image
27+
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
28+
29+
from ...test_pipelines_common import PipelineTesterMixin
30+
31+
32+
torch.backends.cuda.matmul.allow_tf32 = False
33+
34+
35+
class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
36+
def tearDown(self):
37+
# clean up the VRAM after each test
38+
super().tearDown()
39+
gc.collect()
40+
torch.cuda.empty_cache()
41+
42+
@property
43+
def dummy_image(self):
44+
batch_size = 1
45+
num_channels = 3
46+
sizes = (32, 32)
47+
48+
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
49+
return image
50+
51+
@property
52+
def dummy_cond_unet_inpaint(self):
53+
torch.manual_seed(0)
54+
model = UNet2DConditionModel(
55+
block_out_channels=(32, 64),
56+
layers_per_block=2,
57+
sample_size=32,
58+
in_channels=9,
59+
out_channels=4,
60+
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
61+
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
62+
cross_attention_dim=32,
63+
# SD2-specific config below
64+
attention_head_dim=(2, 4, 8, 8),
65+
use_linear_projection=True,
66+
)
67+
return model
68+
69+
@property
70+
def dummy_vae(self):
71+
torch.manual_seed(0)
72+
model = AutoencoderKL(
73+
block_out_channels=[32, 64],
74+
in_channels=3,
75+
out_channels=3,
76+
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
77+
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
78+
latent_channels=4,
79+
)
80+
return model
81+
82+
@property
83+
def dummy_text_encoder(self):
84+
torch.manual_seed(0)
85+
config = CLIPTextConfig(
86+
bos_token_id=0,
87+
eos_token_id=2,
88+
hidden_size=32,
89+
intermediate_size=37,
90+
layer_norm_eps=1e-05,
91+
num_attention_heads=4,
92+
num_hidden_layers=5,
93+
pad_token_id=1,
94+
vocab_size=1000,
95+
# SD2-specific config below
96+
hidden_act="gelu",
97+
projection_dim=512,
98+
)
99+
return CLIPTextModel(config)
100+
101+
@property
102+
def dummy_extractor(self):
103+
def extract(*args, **kwargs):
104+
class Out:
105+
def __init__(self):
106+
self.pixel_values = torch.ones([0])
107+
108+
def to(self, device):
109+
self.pixel_values.to(device)
110+
return self
111+
112+
return Out()
113+
114+
return extract
115+
116+
def test_stable_diffusion_inpaint(self):
117+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
118+
unet = self.dummy_cond_unet_inpaint
119+
scheduler = PNDMScheduler(skip_prk_steps=True)
120+
vae = self.dummy_vae
121+
text_encoder = self.dummy_text_encoder
122+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
123+
124+
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
125+
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
126+
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
127+
128+
# make sure here that pndm scheduler skips prk
129+
sd_pipe = StableDiffusionInpaintPipeline(
130+
unet=unet,
131+
scheduler=scheduler,
132+
vae=vae,
133+
text_encoder=text_encoder,
134+
tokenizer=tokenizer,
135+
safety_checker=None,
136+
feature_extractor=None,
137+
)
138+
sd_pipe = sd_pipe.to(device)
139+
sd_pipe.set_progress_bar_config(disable=None)
140+
141+
prompt = "A painting of a squirrel eating a burger"
142+
generator = torch.Generator(device=device).manual_seed(0)
143+
output = sd_pipe(
144+
[prompt],
145+
generator=generator,
146+
guidance_scale=6.0,
147+
num_inference_steps=2,
148+
output_type="np",
149+
image=init_image,
150+
mask_image=mask_image,
151+
)
152+
153+
image = output.images
154+
155+
generator = torch.Generator(device=device).manual_seed(0)
156+
image_from_tuple = sd_pipe(
157+
[prompt],
158+
generator=generator,
159+
guidance_scale=6.0,
160+
num_inference_steps=2,
161+
output_type="np",
162+
image=init_image,
163+
mask_image=mask_image,
164+
return_dict=False,
165+
)[0]
166+
167+
image_slice = image[0, -3:, -3:, -1]
168+
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
169+
170+
assert image.shape == (1, 64, 64, 3)
171+
expected_slice = np.array([0.4727, 0.5735, 0.3941, 0.5446, 0.5926, 0.4394, 0.5062, 0.4654, 0.4476])
172+
173+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
174+
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
175+
176+
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
177+
def test_stable_diffusion_inpaint_fp16(self):
178+
"""Test that stable diffusion inpaint works with fp16"""
179+
unet = self.dummy_cond_unet_inpaint
180+
scheduler = PNDMScheduler(skip_prk_steps=True)
181+
vae = self.dummy_vae
182+
text_encoder = self.dummy_text_encoder
183+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
184+
185+
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
186+
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
187+
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
188+
189+
# put models in fp16
190+
unet = unet.half()
191+
vae = vae.half()
192+
text_encoder = text_encoder.half()
193+
194+
# make sure here that pndm scheduler skips prk
195+
sd_pipe = StableDiffusionInpaintPipeline(
196+
unet=unet,
197+
scheduler=scheduler,
198+
vae=vae,
199+
text_encoder=text_encoder,
200+
tokenizer=tokenizer,
201+
safety_checker=None,
202+
feature_extractor=None,
203+
)
204+
sd_pipe = sd_pipe.to(torch_device)
205+
sd_pipe.set_progress_bar_config(disable=None)
206+
207+
prompt = "A painting of a squirrel eating a burger"
208+
generator = torch.Generator(device=torch_device).manual_seed(0)
209+
image = sd_pipe(
210+
[prompt],
211+
generator=generator,
212+
num_inference_steps=2,
213+
output_type="np",
214+
image=init_image,
215+
mask_image=mask_image,
216+
).images
217+
218+
assert image.shape == (1, 64, 64, 3)
219+
220+
221+
# @slow
222+
@require_torch_gpu
223+
class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
224+
def tearDown(self):
225+
# clean up the VRAM after each test
226+
super().tearDown()
227+
gc.collect()
228+
torch.cuda.empty_cache()
229+
230+
def test_stable_diffusion_inpaint_pipeline(self):
231+
init_image = load_image(
232+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
233+
"/sd2-inpaint/init_image.png"
234+
)
235+
mask_image = load_image(
236+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
237+
)
238+
expected_image = load_numpy(
239+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
240+
"/yellow_cat_sitting_on_a_park_bench.npy"
241+
)
242+
243+
model_id = "stabilityai/stable-diffusion-2-inpainting"
244+
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
245+
pipe.to(torch_device)
246+
pipe.set_progress_bar_config(disable=None)
247+
pipe.enable_attention_slicing()
248+
249+
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
250+
251+
generator = torch.Generator(device=torch_device).manual_seed(0)
252+
output = pipe(
253+
prompt=prompt,
254+
image=init_image,
255+
mask_image=mask_image,
256+
generator=generator,
257+
output_type="np",
258+
)
259+
image = output.images[0]
260+
261+
assert image.shape == (512, 512, 3)
262+
assert np.abs(expected_image - image).max() < 1e-3
263+
264+
def test_stable_diffusion_inpaint_pipeline_fp16(self):
265+
init_image = load_image(
266+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
267+
"/sd2-inpaint/init_image.png"
268+
)
269+
mask_image = load_image(
270+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
271+
)
272+
expected_image = load_numpy(
273+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
274+
"/yellow_cat_sitting_on_a_park_bench_fp16.npy"
275+
)
276+
277+
model_id = "stabilityai/stable-diffusion-2-inpainting"
278+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
279+
model_id,
280+
revision="fp16",
281+
torch_dtype=torch.float16,
282+
safety_checker=None,
283+
)
284+
pipe.to(torch_device)
285+
pipe.set_progress_bar_config(disable=None)
286+
pipe.enable_attention_slicing()
287+
288+
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
289+
290+
generator = torch.Generator(device=torch_device).manual_seed(0)
291+
output = pipe(
292+
prompt=prompt,
293+
image=init_image,
294+
mask_image=mask_image,
295+
generator=generator,
296+
output_type="np",
297+
)
298+
image = output.images[0]
299+
300+
assert image.shape == (512, 512, 3)
301+
assert np.abs(expected_image - image).max() < 5e-1
302+
303+
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
304+
torch.cuda.empty_cache()
305+
torch.cuda.reset_max_memory_allocated()
306+
torch.cuda.reset_peak_memory_stats()
307+
308+
init_image = load_image(
309+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
310+
"/sd2-inpaint/init_image.png"
311+
)
312+
mask_image = load_image(
313+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
314+
)
315+
316+
model_id = "stabilityai/stable-diffusion-2-inpainting"
317+
pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
318+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
319+
model_id,
320+
safety_checker=None,
321+
scheduler=pndm,
322+
device_map="auto",
323+
revision="fp16",
324+
torch_dtype=torch.float16,
325+
)
326+
pipe.to(torch_device)
327+
pipe.set_progress_bar_config(disable=None)
328+
pipe.enable_attention_slicing(1)
329+
pipe.enable_sequential_cpu_offload()
330+
331+
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
332+
333+
generator = torch.Generator(device=torch_device).manual_seed(0)
334+
_ = pipe(
335+
prompt=prompt,
336+
image=init_image,
337+
mask_image=mask_image,
338+
generator=generator,
339+
num_inference_steps=5,
340+
output_type="np",
341+
)
342+
343+
mem_bytes = torch.cuda.max_memory_allocated()
344+
# make sure that less than 2.65 GB is allocated
345+
assert mem_bytes < 2.65 * 10**9

0 commit comments

Comments
 (0)