
Commit 9ec5084

StableDiffusionUpscalePipeline (#1396)
* StableDiffusionUpscalePipeline
* fix a few things
* make it better
* fix image batching
* run vae in fp32
* fix docstr
* resize to mul of 64
* doc
* remove safety_checker
* add max_noise_level
* fix Copied
* begin tests
* slow tests
* default max_noise_level
* remove kwargs
* doc
* fix
* fix fast tests
* fix fast tests
* no sf
* don't offload vae

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 02aa4ef commit 9ec5084

File tree

8 files changed: +896 −3 lines

8 files changed

+896
-3
lines changed

docs/source/api/pipelines/stable_diffusion.mdx

Lines changed: 7 additions & 0 deletions
@@ -95,3 +95,10 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
 	- __call__
 	- enable_attention_slicing
 	- disable_attention_slicing
+
+
+## StableDiffusionUpscalePipeline
+[[autodoc]] StableDiffusionUpscalePipeline
+	- __call__
+	- enable_attention_slicing
+	- disable_attention_slicing
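
For orientation, here is a minimal usage sketch of the new pipeline assembled from the integration tests later in this commit; the checkpoint id, prompt, and input image are taken from those tests, and the 4x output size follows from the test assertion `expected_height_width = low_res_image.size[0] * 4`:

import torch
from diffusers import StableDiffusionUpscalePipeline
from diffusers.utils import load_image

# checkpoint and inputs mirror the slow tests added in this commit
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_attention_slicing()  # documented above alongside __call__

low_res_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/sd2-upscale/low_res_cat.png"
)

# the x4 upscaler maps a 128x128 input to a 512x512 output
image = pipe(prompt="a cat sitting on a park bench", image=low_res_image).images[0]
image.save("upsampled_cat.png")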

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@
     StableDiffusionInpaintPipelineLegacy,
     StableDiffusionPipeline,
     StableDiffusionPipelineSafe,
+    StableDiffusionUpscalePipeline,
     VersatileDiffusionDualGuidedPipeline,
     VersatileDiffusionImageVariationPipeline,
     VersatileDiffusionPipeline,

src/diffusers/pipeline_utils.py

Lines changed: 5 additions & 3 deletions
@@ -554,7 +554,9 @@ def load_module(name, value):
         init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
 
         if len(unused_kwargs) > 0:
-            logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
+            logger.warning(
+                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
+            )
 
         # import it here to avoid circular import
         from diffusers import pipelines
@@ -680,8 +682,8 @@ def load_module(name, value):
     @staticmethod
     def _get_signature_keys(obj):
         parameters = inspect.signature(obj.__init__).parameters
-        required_parameters = {k: v for k, v in parameters.items() if v.default is not True}
-        optional_parameters = set({k for k, v in parameters.items() if v.default is True})
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
         expected_modules = set(required_parameters.keys()) - set(["self"])
         return expected_modules, optional_parameters
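The `_get_signature_keys` fix is subtle: the old predicate `v.default is not True` classified a parameter as required whenever its default was anything other than the literal `True`, so a parameter like `safety_checker=None` would wrongly count as required. Comparing against `inspect._empty` correctly tests for the absence of a default. A standalone illustration; the `Example` class below is hypothetical, purely for demonstration:

import inspect

class Example:
    def __init__(self, unet, scheduler, safety_checker=None, requires_safety_checker=True):
        pass

params = inspect.signature(Example.__init__).parameters
# a parameter is required exactly when it has no default at all
required = {k for k, v in params.items() if v.default == inspect._empty} - {"self"}
optional = {k for k, v in params.items() if v.default != inspect._empty}

print(sorted(required))  # ['scheduler', 'unet']
print(sorted(optional))  # ['requires_safety_checker', 'safety_checker']

# the old check `v.default is not True` would have put safety_checker
# (default None) into the required set, since None is not the literal True
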
src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
     StableDiffusionInpaintPipeline,
     StableDiffusionInpaintPipelineLegacy,
     StableDiffusionPipeline,
+    StableDiffusionUpscalePipeline,
 )
 from .stable_diffusion_safe import StableDiffusionPipelineSafe
 from .versatile_diffusion import (

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
 from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
 from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
 from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
+from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
 from .safety_checker import StableDiffusionSafetyChecker
 
 if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 551 additions & 0 deletions
Large diffs are not rendered by default.
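
The pipeline source itself is collapsed in this view, but its key mechanics can be read off the component wiring in the tests below: the UNet takes `in_channels=7` (4 latent channels plus the 3 channels of the low-resolution image), a separate DDPM `low_res_scheduler` noises the conditioning image according to `noise_level` (validated against `max_noise_level`), and `noise_level` also conditions the UNet through its class-embedding pathway (`num_class_embeds`). A conceptual sketch of the resulting loop; this is an inference from the tests and commit message, not the actual file contents:

import torch

def upscale_loop_sketch(pipe, latents, low_res_image, noise_level, text_embeddings, timesteps, generator=None):
    # noise the low-res conditioning image to the requested noise_level
    noise_level = torch.tensor([noise_level], dtype=torch.long, device=low_res_image.device)
    noise = torch.randn(low_res_image.shape, generator=generator, device=low_res_image.device)
    low_res_image = pipe.low_res_scheduler.add_noise(low_res_image, noise, noise_level)

    for t in timesteps:
        # 4 latent channels + 3 image channels = the UNet's 7 input channels
        model_input = torch.cat([latents, low_res_image], dim=1)
        # noise_level enters the UNet as a class label (num_class_embeds pathway)
        noise_pred = pipe.unet(
            model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
        ).sample
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # decode with the VAE kept in fp32 ("run vae in fp32" in the commit message,
    # since it overflows in fp16); latent rescaling and classifier-free guidance
    # are omitted from this sketch
    image = pipe.vae.decode(latents.float()).sample
    return image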

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
@@ -154,6 +154,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class StableDiffusionUpscalePipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]

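These dummy classes keep `from diffusers import StableDiffusionUpscalePipeline` importable even when `torch` or `transformers` is missing, deferring the failure until the class is actually used and surfacing it as a readable backend error. A simplified sketch of the mechanism; the real `DummyObject` and `requires_backends` live in `diffusers.utils` and additionally check whether each backend is importable:

# simplified: always raises, whereas the real requires_backends first checks
# whether each listed backend can actually be imported
def requires_backends(obj, backends):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class DummyObject(type):
    # metaclass __getattr__ makes classmethod access (from_pretrained, ...) fail too
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)

class StableDiffusionUpscalePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

# importing succeeds; using the class raises a clear error instead of a bare
# ModuleNotFoundError deep inside the library
try:
    StableDiffusionUpscalePipeline()
except ImportError as err:
    print(err)  # StableDiffusionUpscalePipeline requires the following backends: torch, transformers
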
tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py

Lines changed: 315 additions & 0 deletions
@@ -0,0 +1,315 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch

from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionUpscalePipeline, UNet2DConditionModel
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class StableDiffusionUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_cond_unet_upscale(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=7,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            # SD2-specific config below
            attention_head_dim=8,
            use_linear_projection=True,
            only_cross_attention=(True, True, False),
            num_class_embeds=100,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=512,
        )
        return CLIPTextModel(config)

    def test_stable_diffusion_upscale(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet_upscale
        low_res_scheduler = DDPMScheduler()
        scheduler = DDIMScheduler(prediction_type="v_prediction")
        vae = self.dummy_vae
        text_encoder = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

        # assemble the pipeline from the dummy components above
        sd_pipe = StableDiffusionUpscalePipeline(
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            max_noise_level=350,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
        )

        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        expected_height_width = low_res_image.size[0] * 4
        assert image.shape == (1, expected_height_width, expected_height_width, 3)
        expected_slice = np.array([0.2562, 0.3606, 0.4204, 0.4469, 0.4822, 0.4647, 0.5315, 0.5748, 0.5606])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
    def test_stable_diffusion_upscale_fp16(self):
        """Test that stable diffusion upscale works with fp16"""
        unet = self.dummy_cond_unet_upscale
        low_res_scheduler = DDPMScheduler()
        scheduler = DDIMScheduler(prediction_type="v_prediction")
        vae = self.dummy_vae
        text_encoder = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))

        # put models in fp16, except vae as it overflows in fp16
        unet = unet.half()
        text_encoder = text_encoder.half()

        # assemble the pipeline from the dummy components above
        sd_pipe = StableDiffusionUpscalePipeline(
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            max_noise_level=350,
        )
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = sd_pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        ).images

        expected_height_width = low_res_image.size[0] * 4
        assert image.shape == (1, expected_height_width, expected_height_width, 3)


@slow
@require_torch_gpu
class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_diffusion_upscale_pipeline(self):
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale"
            "/upsampled_cat.npy"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "a cat sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-3

    def test_stable_diffusion_upscale_pipeline_fp16(self):
        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale"
            "/upsampled_cat_fp16.npy"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(
            model_id,
            revision="fp16",
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "a cat sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 5e-1

    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/sd2-upscale/low_res_cat.png"
        )

        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(
            model_id,
            revision="fp16",
            torch_dtype=torch.float16,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        prompt = "a cat sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        _ = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            num_inference_steps=5,
            output_type="np",
        )

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.65 GB is allocated
        assert mem_bytes < 2.65 * 10**9
