Skip to content

Commit 02aa4ef

Browse files
authored
Add tests for Stable Diffusion 2 V-prediction 768x768 (#1420)
1 parent 8faa822 commit 02aa4ef

File tree

2 files changed

+495
-28
lines changed

2 files changed

+495
-28
lines changed

tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
)
3535
from diffusers.utils import load_numpy, slow, torch_device
3636
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
37-
from transformers import CLIPFeatureExtractor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
37+
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
3838

3939
from ...test_pipelines_common import PipelineTesterMixin
4040

@@ -100,21 +100,6 @@ def dummy_text_encoder(self):
100100
)
101101
return CLIPTextModel(config)
102102

103-
@property
104-
def dummy_extractor(self):
105-
def extract(*args, **kwargs):
106-
class Out:
107-
def __init__(self):
108-
self.pixel_values = torch.ones([0])
109-
110-
def to(self, device):
111-
self.pixel_values.to(device)
112-
return self
113-
114-
return Out()
115-
116-
return extract
117-
118103
def test_save_pretrained_from_pretrained(self):
119104
device = "cpu" # ensure determinism for the device-dependent torch.Generator
120105
unet = self.dummy_cond_unet
@@ -129,7 +114,6 @@ def test_save_pretrained_from_pretrained(self):
129114
vae = self.dummy_vae
130115
bert = self.dummy_text_encoder
131116
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
132-
feature_extractor = CLIPFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-clip")
133117

134118
# make sure here that pndm scheduler skips prk
135119
sd_pipe = StableDiffusionPipeline(
@@ -139,7 +123,8 @@ def test_save_pretrained_from_pretrained(self):
139123
text_encoder=bert,
140124
tokenizer=tokenizer,
141125
safety_checker=None,
142-
feature_extractor=feature_extractor,
126+
feature_extractor=None,
127+
requires_safety_checker=False,
143128
)
144129
sd_pipe = sd_pipe.to(device)
145130
sd_pipe.set_progress_bar_config(disable=None)
@@ -185,7 +170,8 @@ def test_stable_diffusion_ddim(self):
185170
text_encoder=bert,
186171
tokenizer=tokenizer,
187172
safety_checker=None,
188-
feature_extractor=self.dummy_extractor,
173+
feature_extractor=None,
174+
requires_safety_checker=False,
189175
)
190176
sd_pipe = sd_pipe.to(device)
191177
sd_pipe.set_progress_bar_config(disable=None)
@@ -231,7 +217,8 @@ def test_stable_diffusion_pndm(self):
231217
text_encoder=bert,
232218
tokenizer=tokenizer,
233219
safety_checker=None,
234-
feature_extractor=self.dummy_extractor,
220+
feature_extractor=None,
221+
requires_safety_checker=False,
235222
)
236223
sd_pipe = sd_pipe.to(device)
237224
sd_pipe.set_progress_bar_config(disable=None)
@@ -276,7 +263,8 @@ def test_stable_diffusion_k_lms(self):
276263
text_encoder=bert,
277264
tokenizer=tokenizer,
278265
safety_checker=None,
279-
feature_extractor=self.dummy_extractor,
266+
feature_extractor=None,
267+
requires_safety_checker=False,
280268
)
281269
sd_pipe = sd_pipe.to(device)
282270
sd_pipe.set_progress_bar_config(disable=None)
@@ -321,7 +309,8 @@ def test_stable_diffusion_k_euler_ancestral(self):
321309
text_encoder=bert,
322310
tokenizer=tokenizer,
323311
safety_checker=None,
324-
feature_extractor=self.dummy_extractor,
312+
feature_extractor=None,
313+
requires_safety_checker=False,
325314
)
326315
sd_pipe = sd_pipe.to(device)
327316
sd_pipe.set_progress_bar_config(disable=None)
@@ -366,7 +355,8 @@ def test_stable_diffusion_k_euler(self):
366355
text_encoder=bert,
367356
tokenizer=tokenizer,
368357
safety_checker=None,
369-
feature_extractor=self.dummy_extractor,
358+
feature_extractor=None,
359+
requires_safety_checker=False,
370360
)
371361
sd_pipe = sd_pipe.to(device)
372362
sd_pipe.set_progress_bar_config(disable=None)
@@ -411,7 +401,8 @@ def test_stable_diffusion_attention_chunk(self):
411401
text_encoder=bert,
412402
tokenizer=tokenizer,
413403
safety_checker=None,
414-
feature_extractor=self.dummy_extractor,
404+
feature_extractor=None,
405+
requires_safety_checker=False,
415406
)
416407
sd_pipe = sd_pipe.to(device)
417408
sd_pipe.set_progress_bar_config(disable=None)
@@ -449,7 +440,8 @@ def test_stable_diffusion_fp16(self):
449440
text_encoder=bert,
450441
tokenizer=tokenizer,
451442
safety_checker=None,
452-
feature_extractor=self.dummy_extractor,
443+
feature_extractor=None,
444+
requires_safety_checker=False,
453445
)
454446
sd_pipe = sd_pipe.to(torch_device)
455447
sd_pipe.set_progress_bar_config(disable=None)
@@ -475,7 +467,8 @@ def test_stable_diffusion_long_prompt(self):
475467
text_encoder=bert,
476468
tokenizer=tokenizer,
477469
safety_checker=None,
478-
feature_extractor=self.dummy_extractor,
470+
feature_extractor=None,
471+
requires_safety_checker=False,
479472
)
480473
sd_pipe = sd_pipe.to(torch_device)
481474
sd_pipe.set_progress_bar_config(disable=None)
@@ -572,7 +565,7 @@ def test_stable_diffusion_k_lms(self):
572565
expected_slice = np.array([0.0548, 0.0626, 0.0612, 0.0611, 0.0706, 0.0586, 0.0843, 0.0333, 0.1197])
573566
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
574567

575-
def test_stable_diffusion_memory_chunking(self):
568+
def test_stable_diffusion_attention_slicing(self):
576569
torch.cuda.reset_peak_memory_stats()
577570
model_id = "stabilityai/stable-diffusion-2-base"
578571
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
@@ -651,7 +644,7 @@ def test_stable_diffusion_text2img_pipeline_default(self):
651644
prompt = "astronaut riding a horse"
652645

653646
generator = torch.Generator(device=torch_device).manual_seed(0)
654-
output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
647+
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
655648
image = output.images[0]
656649

657650
assert image.shape == (512, 512, 3)

0 commit comments

Comments
 (0)