Skip to content

Commit e0e86b7

Browse files
Make height and width optional (#1401)
* fix * add test * fix test * uP * up * fix some tests
1 parent 81d8f4a commit e0e86b7

20 files changed

+176
-95
lines changed

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
390390
def __call__(
391391
self,
392392
prompt: Union[str, List[str]],
393-
height: int = 512,
394-
width: int = 512,
393+
height: Optional[int] = None,
394+
width: Optional[int] = None,
395395
num_inference_steps: int = 50,
396396
guidance_scale: float = 7.5,
397397
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -411,9 +411,9 @@ def __call__(
411411
Args:
412412
prompt (`str` or `List[str]`):
413413
The prompt or prompts to guide the image generation.
414-
height (`int`, *optional*, defaults to 512):
414+
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
415415
The height in pixels of the generated image.
416-
width (`int`, *optional*, defaults to 512):
416+
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
417417
The width in pixels of the generated image.
418418
num_inference_steps (`int`, *optional*, defaults to 50):
419419
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -459,6 +459,9 @@ def __call__(
459459
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
460460
(nsfw) content, according to the `safety_checker`.
461461
"""
462+
# 0. Default height and width to unet
463+
height = height or self.unet.config.sample_size * 8
464+
width = width or self.unet.config.sample_size * 8
462465

463466
# 1. Check inputs. Raise error if not correct
464467
self.check_inputs(prompt, height, width, callback_steps)

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def __init__(
6565
def __call__(
6666
self,
6767
prompt: Union[str, List[str]],
68-
height: Optional[int] = 256,
69-
width: Optional[int] = 256,
68+
height: Optional[int] = None,
69+
width: Optional[int] = None,
7070
num_inference_steps: Optional[int] = 50,
7171
guidance_scale: Optional[float] = 1.0,
7272
eta: Optional[float] = 0.0,
@@ -79,9 +79,9 @@ def __call__(
7979
Args:
8080
prompt (`str` or `List[str]`):
8181
The prompt or prompts to guide the image generation.
82-
height (`int`, *optional*, defaults to 256):
82+
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
8383
The height in pixels of the generated image.
84-
width (`int`, *optional*, defaults to 256):
84+
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
8585
The width in pixels of the generated image.
8686
num_inference_steps (`int`, *optional*, defaults to 50):
8787
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -106,6 +106,9 @@ def __call__(
106106
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
107107
generated images.
108108
"""
109+
# 0. Default height and width to unet
110+
height = height or self.unet.config.sample_size * 8
111+
width = width or self.unet.config.sample_size * 8
109112

110113
if isinstance(prompt, str):
111114
batch_size = 1

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,17 @@ def _generate(
160160
params: Union[Dict, FrozenDict],
161161
prng_seed: jax.random.PRNGKey,
162162
num_inference_steps: int = 50,
163-
height: int = 512,
164-
width: int = 512,
163+
height: Optional[int] = None,
164+
width: Optional[int] = None,
165165
guidance_scale: float = 7.5,
166166
latents: Optional[jnp.array] = None,
167167
debug: bool = False,
168168
neg_prompt_ids: jnp.array = None,
169169
):
170+
# 0. Default height and width to unet
171+
height = height or self.unet.config.sample_size * 8
172+
width = width or self.unet.config.sample_size * 8
173+
170174
if height % 8 != 0 or width % 8 != 0:
171175
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
172176

@@ -249,8 +253,8 @@ def __call__(
249253
params: Union[Dict, FrozenDict],
250254
prng_seed: jax.random.PRNGKey,
251255
num_inference_steps: int = 50,
252-
height: int = 512,
253-
width: int = 512,
256+
height: Optional[int] = None,
257+
width: Optional[int] = None,
254258
guidance_scale: float = 7.5,
255259
latents: jnp.array = None,
256260
return_dict: bool = True,
@@ -265,9 +269,9 @@ def __call__(
265269
Args:
266270
prompt (`str` or `List[str]`):
267271
The prompt or prompts to guide the image generation.
268-
height (`int`, *optional*, defaults to 512):
272+
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
269273
The height in pixels of the generated image.
270-
width (`int`, *optional*, defaults to 512):
274+
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
271275
The width in pixels of the generated image.
272276
num_inference_steps (`int`, *optional*, defaults to 50):
273277
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -302,6 +306,10 @@ def __call__(
302306
element is a list of `bool`s denoting whether the corresponding generated image likely represents
303307
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
304308
"""
309+
# 0. Default height and width to unet
310+
height = height or self.unet.config.sample_size * 8
311+
width = width or self.unet.config.sample_size * 8
312+
305313
if jit:
306314
images = _p_generate(
307315
self,

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
172172
def __call__(
173173
self,
174174
prompt: Union[str, List[str]],
175-
height: Optional[int] = 512,
176-
width: Optional[int] = 512,
175+
height: Optional[int] = None,
176+
width: Optional[int] = None,
177177
num_inference_steps: Optional[int] = 50,
178178
guidance_scale: Optional[float] = 7.5,
179179
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -187,6 +187,10 @@ def __call__(
187187
callback_steps: Optional[int] = 1,
188188
**kwargs,
189189
):
190+
# 0. Default height and width to unet
191+
height = height or self.unet.config.sample_size * 8
192+
width = width or self.unet.config.sample_size * 8
193+
190194
if isinstance(prompt, str):
191195
batch_size = 1
192196
elif isinstance(prompt, list):

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,8 @@ def __call__(
236236
prompt: Union[str, List[str]],
237237
image: PIL.Image.Image,
238238
mask_image: PIL.Image.Image,
239-
height: int = 512,
240-
width: int = 512,
239+
height: Optional[int] = None,
240+
width: Optional[int] = None,
241241
num_inference_steps: int = 50,
242242
guidance_scale: float = 7.5,
243243
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -265,9 +265,9 @@ def __call__(
265265
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
266266
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
267267
instead of 3, so the expected shape would be `(B, H, W, 1)`.
268-
height (`int`, *optional*, defaults to 512):
268+
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
269269
The height in pixels of the generated image.
270-
width (`int`, *optional*, defaults to 512):
270+
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
271271
The width in pixels of the generated image.
272272
num_inference_steps (`int`, *optional*, defaults to 50):
273273
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -312,6 +312,10 @@ def __call__(
312312
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
313313
(nsfw) content, according to the `safety_checker`.
314314
"""
315+
# 0. Default height and width to unet
316+
height = height or self.unet.config.sample_size * 8
317+
width = width or self.unet.config.sample_size * 8
318+
315319
if isinstance(prompt, str):
316320
batch_size = 1
317321
elif isinstance(prompt, list):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
389389
def __call__(
390390
self,
391391
prompt: Union[str, List[str]],
392-
height: int = 512,
393-
width: int = 512,
392+
height: Optional[int] = None,
393+
width: Optional[int] = None,
394394
num_inference_steps: int = 50,
395395
guidance_scale: float = 7.5,
396396
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -410,9 +410,9 @@ def __call__(
410410
Args:
411411
prompt (`str` or `List[str]`):
412412
The prompt or prompts to guide the image generation.
413-
height (`int`, *optional*, defaults to 512):
413+
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
414414
The height in pixels of the generated image.
415-
width (`int`, *optional*, defaults to 512):
415+
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
416416
The width in pixels of the generated image.
417417
num_inference_steps (`int`, *optional*, defaults to 50):
418418
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -458,6 +458,9 @@ def __call__(
458458
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
459459
(nsfw) content, according to the `safety_checker`.
460460
"""
461+
# 0. Default height and width to unet
462+
height = height or self.unet.config.sample_size * 8
463+
width = width or self.unet.config.sample_size * 8
461464

462465
# 1. Check inputs. Raise error if not correct
463466
self.check_inputs(prompt, height, width, callback_steps)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
292292
def __call__(
293293
self,
294294
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
295-
height: int = 512,
296-
width: int = 512,
295+
height: Optional[int] = None,
296+
width: Optional[int] = None,
297297
num_inference_steps: int = 50,
298298
guidance_scale: float = 7.5,
299299
num_images_per_prompt: Optional[int] = 1,
@@ -315,9 +315,9 @@ def __call__(
315315
configuration of
316316
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
317317
`CLIPFeatureExtractor`
318-
height (`int`, *optional*, defaults to 512):
318+
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
319319
The height in pixels of the generated image.
320-
width (`int`, *optional*, defaults to 512):
320+
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
321321
The width in pixels of the generated image.
322322
num_inference_steps (`int`, *optional*, defaults to 50):
323323
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -360,6 +360,9 @@ def __call__(
360360
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
361361
(nsfw) content, according to the `safety_checker`.
362362
"""
363+
# 0. Default height and width to unet
364+
height = height or self.unet.config.sample_size * 8
365+
width = width or self.unet.config.sample_size * 8
363366

364367
# 1. Check inputs. Raise error if not correct
365368
self.check_inputs(image, height, width, callback_steps)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,8 @@ def __call__(
509509
prompt: Union[str, List[str]],
510510
image: Union[torch.FloatTensor, PIL.Image.Image],
511511
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
512-
height: int = 512,
513-
width: int = 512,
512+
height: Optional[int] = None,
513+
width: Optional[int] = None,
514514
num_inference_steps: int = 50,
515515
guidance_scale: float = 7.5,
516516
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -538,9 +538,9 @@ def __call__(
538538
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
539539
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
540540
instead of 3, so the expected shape would be `(B, H, W, 1)`.
541-
height (`int`, *optional*, defaults to 512):
541+
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
542542
The height in pixels of the generated image.
543-
width (`int`, *optional*, defaults to 512):
543+
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
544544
The width in pixels of the generated image.
545545
num_inference_steps (`int`, *optional*, defaults to 50):
546546
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -586,6 +586,9 @@ def __call__(
586586
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
587587
(nsfw) content, according to the `safety_checker`.
588588
"""
589+
# 0. Default height and width to unet
590+
height = height or self.unet.config.sample_size * 8
591+
width = width or self.unet.config.sample_size * 8
589592

590593
# 1. Check inputs
591594
self.check_inputs(prompt, height, width, callback_steps)

src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,8 @@ def perform_safety_guidance(
495495
def __call__(
496496
self,
497497
prompt: Union[str, List[str]],
498-
height: int = 512,
499-
width: int = 512,
498+
height: Optional[int] = None,
499+
width: Optional[int] = None,
500500
num_inference_steps: int = 50,
501501
guidance_scale: float = 7.5,
502502
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -521,9 +521,9 @@ def __call__(
521521
Args:
522522
prompt (`str` or `List[str]`):
523523
The prompt or prompts to guide the image generation.
524-
height (`int`, *optional*, defaults to 512):
524+
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
525525
The height in pixels of the generated image.
526-
width (`int`, *optional*, defaults to 512):
526+
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
527527
The width in pixels of the generated image.
528528
num_inference_steps (`int`, *optional*, defaults to 50):
529529
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -589,6 +589,9 @@ def __call__(
589589
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
590590
(nsfw) content, according to the `safety_checker`.
591591
"""
592+
# 0. Default height and width to unet
593+
height = height or self.unet.config.sample_size * 8
594+
width = width or self.unet.config.sample_size * 8
592595

593596
# 1. Check inputs. Raise error if not correct
594597
self.check_inputs(prompt, height, width, callback_steps)

src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def disable_attention_slicing(self):
111111
def image_variation(
112112
self,
113113
image: Union[torch.FloatTensor, PIL.Image.Image],
114-
height: int = 512,
115-
width: int = 512,
114+
height: Optional[int] = None,
115+
width: Optional[int] = None,
116116
num_inference_steps: int = 50,
117117
guidance_scale: float = 7.5,
118118
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -131,9 +131,9 @@ def image_variation(
131131
Args:
132132
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
133133
The image prompt or prompts to guide the image generation.
134-
height (`int`, *optional*, defaults to 512):
134+
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
135135
The height in pixels of the generated image.
136-
width (`int`, *optional*, defaults to 512):
136+
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
137137
The width in pixels of the generated image.
138138
num_inference_steps (`int`, *optional*, defaults to 50):
139139
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -193,7 +193,7 @@ def image_variation(
193193
>>> pipe = pipe.to("cuda")
194194
195195
>>> generator = torch.Generator(device="cuda").manual_seed(0)
196-
>>> image = pipe(image, generator=generator).images[0]
196+
>>> image = pipe.image_variation(image, generator=generator).images[0]
197197
>>> image.save("./car_variation.png")
198198
```
199199
@@ -227,8 +227,8 @@ def image_variation(
227227
def text_to_image(
228228
self,
229229
prompt: Union[str, List[str]],
230-
height: int = 512,
231-
width: int = 512,
230+
height: Optional[int] = None,
231+
width: Optional[int] = None,
232232
num_inference_steps: int = 50,
233233
guidance_scale: float = 7.5,
234234
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -247,9 +247,9 @@ def text_to_image(
247247
Args:
248248
prompt (`str` or `List[str]`):
249249
The prompt or prompts to guide the image generation.
250-
height (`int`, *optional*, defaults to 512):
250+
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
251251
The height in pixels of the generated image.
252-
width (`int`, *optional*, defaults to 512):
252+
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
253253
The width in pixels of the generated image.
254254
num_inference_steps (`int`, *optional*, defaults to 50):
255255
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -341,8 +341,8 @@ def dual_guided(
341341
prompt: Union[PIL.Image.Image, List[PIL.Image.Image]],
342342
image: Union[str, List[str]],
343343
text_to_image_strength: float = 0.5,
344-
height: int = 512,
345-
width: int = 512,
344+
height: Optional[int] = None,
345+
width: Optional[int] = None,
346346
num_inference_steps: int = 50,
347347
guidance_scale: float = 7.5,
348348
num_images_per_prompt: Optional[int] = 1,
@@ -360,9 +360,9 @@ def dual_guided(
360360
Args:
361361
prompt (`str` or `List[str]`):
362362
The prompt or prompts to guide the image generation.
363-
height (`int`, *optional*, defaults to 512):
363+
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
364364
The height in pixels of the generated image.
365-
width (`int`, *optional*, defaults to 512):
365+
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * 8):
366366
The width in pixels of the generated image.
367367
num_inference_steps (`int`, *optional*, defaults to 50):
368368
The number of denoising steps. More denoising steps usually lead to a higher quality image at the

0 commit comments

Comments
 (0)