 )
 from diffusers.utils import load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
-from transformers import CLIPFeatureExtractor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -100,21 +100,6 @@ def dummy_text_encoder(self):
         )
         return CLIPTextModel(config)
 
-    @property
-    def dummy_extractor(self):
-        def extract(*args, **kwargs):
-            class Out:
-                def __init__(self):
-                    self.pixel_values = torch.ones([0])
-
-                def to(self, device):
-                    self.pixel_values.to(device)
-                    return self
-
-            return Out()
-
-        return extract
-
     def test_save_pretrained_from_pretrained(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
@@ -129,7 +114,6 @@ def test_save_pretrained_from_pretrained(self):
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-        feature_extractor = CLIPFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-clip")
 
         # make sure here that pndm scheduler skips prk
         sd_pipe = StableDiffusionPipeline(
@@ -139,7 +123,8 @@ def test_save_pretrained_from_pretrained(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=feature_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -185,7 +170,8 @@ def test_stable_diffusion_ddim(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -231,7 +217,8 @@ def test_stable_diffusion_pndm(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -276,7 +263,8 @@ def test_stable_diffusion_k_lms(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -321,7 +309,8 @@ def test_stable_diffusion_k_euler_ancestral(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -366,7 +355,8 @@ def test_stable_diffusion_k_euler(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -411,7 +401,8 @@ def test_stable_diffusion_attention_chunk(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
        )
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -449,7 +440,8 @@ def test_stable_diffusion_fp16(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(torch_device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -475,7 +467,8 @@ def test_stable_diffusion_long_prompt(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=None,
+            requires_safety_checker=False,
         )
         sd_pipe = sd_pipe.to(torch_device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -572,7 +565,7 @@ def test_stable_diffusion_k_lms(self):
         expected_slice = np.array([0.0548, 0.0626, 0.0612, 0.0611, 0.0706, 0.0586, 0.0843, 0.0333, 0.1197])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
-    def test_stable_diffusion_memory_chunking(self):
+    def test_stable_diffusion_attention_slicing(self):
         torch.cuda.reset_peak_memory_stats()
         model_id = "stabilityai/stable-diffusion-2-base"
         pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
@@ -651,7 +644,7 @@ def test_stable_diffusion_text2img_pipeline_default(self):
         prompt = "astronaut riding a horse"
 
         generator = torch.Generator(device=torch_device).manual_seed(0)
-        output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
+        output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
         image = output.images[0]
 
         assert image.shape == (512, 512, 3)
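
For readers skimming the diff, the constructor pattern every fast test now converges on is sketched below. This is a minimal illustration, not part of the PR: the method name test_stable_diffusion_no_safety_checker_smoke is hypothetical, PNDMScheduler is only one of the schedulers the real tests exercise, and the dummy_cond_unet / dummy_vae / dummy_text_encoder properties are assumed to be the fixtures defined earlier in this test class, with imports taken from the top of the test file.

    # Sketch only (not in this PR): assumed to live inside the same test class
    # and reuse its dummy_* fixtures; torch, PNDMScheduler, StableDiffusionPipeline
    # and CLIPTokenizer are assumed to be imported at module level.
    def test_stable_diffusion_no_safety_checker_smoke(self):
        device = "cpu"  # keep the torch.Generator deterministic
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=None,  # the dummy CLIPFeatureExtractor is no longer needed
            requires_safety_checker=False,  # opt out explicitly so the pipeline does not warn
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(0)
        output = sd_pipe(
            "A painting of a squirrel eating a burger",
            generator=generator,
            num_inference_steps=2,
            output_type="np",
        )
        assert output.images[0].ndim == 3  # one RGB image from the dummy components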