Skip to content

Commit 05a36d5

Browse files
Upscaling fixed (#1402)
* Upscaling fixed * up * more fixes * fix * more fixes * finish again * up
1 parent cbfed0c commit 05a36d5

26 files changed

+226
-120
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
544544
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
545545

546546
# remove `null` components
547-
init_dict = {k: v for k, v in init_dict.items() if v[0] is not None}
547+
def load_module(name, value):
548+
if value[0] is None:
549+
return False
550+
if name in passed_class_obj and passed_class_obj[name] is None:
551+
return False
552+
return True
553+
554+
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
548555

549556
if len(unused_kwargs) > 0:
550557
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
@@ -560,12 +567,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
560567

561568
is_pipeline_module = hasattr(pipelines, library_name)
562569
loaded_sub_model = None
563-
sub_model_should_be_defined = True
564570

565571
# if the model is in a pipeline module, then we load it from the pipeline
566572
if name in passed_class_obj:
567573
# 1. check that passed_class_obj has correct parent class
568-
if not is_pipeline_module and passed_class_obj[name] is not None:
574+
if not is_pipeline_module:
569575
library = importlib.import_module(library_name)
570576
class_obj = getattr(library, class_name)
571577
importable_classes = LOADABLE_CLASSES[library_name]
@@ -581,12 +587,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
581587
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
582588
f" {expected_class_obj}"
583589
)
584-
elif passed_class_obj[name] is None and name not in pipeline_class._optional_components:
585-
logger.warning(
586-
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
587-
f" that this might lead to problems when using {pipeline_class} and is not recommended."
588-
)
589-
sub_model_should_be_defined = False
590590
else:
591591
logger.warning(
592592
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -608,7 +608,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
608608
importable_classes = LOADABLE_CLASSES[library_name]
609609
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
610610

611-
if loaded_sub_model is None and sub_model_should_be_defined:
611+
if loaded_sub_model is None:
612612
load_method_name = None
613613
for class_name, class_candidate in class_candidates.items():
614614
if class_candidate is not None and issubclass(class_obj, class_candidate):

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def __init__(
141141
safety_checker=safety_checker,
142142
feature_extractor=feature_extractor,
143143
)
144+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
144145
self.register_to_config(requires_safety_checker=requires_safety_checker)
145146

146147
def enable_xformers_memory_efficient_attention(self):
@@ -379,7 +380,7 @@ def check_inputs(self, prompt, height, width, callback_steps):
379380
)
380381

381382
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
382-
shape = (batch_size, num_channels_latents, height // 8, width // 8)
383+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
383384
if latents is None:
384385
if device.type == "mps":
385386
# randn does not work reproducibly on mps
@@ -420,9 +421,9 @@ def __call__(
420421
Args:
421422
prompt (`str` or `List[str]`):
422423
The prompt or prompts to guide the image generation.
423-
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
424+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
424425
The height in pixels of the generated image.
425-
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
426+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
426427
The width in pixels of the generated image.
427428
num_inference_steps (`int`, *optional*, defaults to 50):
428429
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -469,8 +470,8 @@ def __call__(
469470
(nsfw) content, according to the `safety_checker`.
470471
"""
471472
# 0. Default height and width to unet
472-
height = height or self.unet.config.sample_size * 8
473-
width = width or self.unet.config.sample_size * 8
473+
height = height or self.unet.config.sample_size * self.vae_scale_factor
474+
width = width or self.unet.config.sample_size * self.vae_scale_factor
474475

475476
# 1. Check inputs. Raise error if not correct
476477
self.check_inputs(prompt, height, width, callback_steps)

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def __init__(
154154
safety_checker=safety_checker,
155155
feature_extractor=feature_extractor,
156156
)
157+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
157158
self.register_to_config(requires_safety_checker=requires_safety_checker)
158159

159160
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
):
6161
super().__init__()
6262
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
63+
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
6364

6465
@torch.no_grad()
6566
def __call__(
@@ -79,9 +80,9 @@ def __call__(
7980
Args:
8081
prompt (`str` or `List[str]`):
8182
The prompt or prompts to guide the image generation.
82-
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
83+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
8384
The height in pixels of the generated image.
84-
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
85+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
8586
The width in pixels of the generated image.
8687
num_inference_steps (`int`, *optional*, defaults to 50):
8788
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -107,8 +108,8 @@ def __call__(
107108
generated images.
108109
"""
109110
# 0. Default height and width to unet
110-
height = height or self.unet.config.sample_size * 8
111-
width = width or self.unet.config.sample_size * 8
111+
height = height or self.unet.config.sample_size * self.vae_scale_factor
112+
width = width or self.unet.config.sample_size * self.vae_scale_factor
112113

113114
if isinstance(prompt, str):
114115
batch_size = 1

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def __init__(
106106
safety_checker=safety_checker,
107107
feature_extractor=feature_extractor,
108108
)
109+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
109110

110111
def prepare_inputs(self, prompt: Union[str, List[str]]):
111112
if not isinstance(prompt, (str, list)):
@@ -168,8 +169,8 @@ def _generate(
168169
neg_prompt_ids: jnp.array = None,
169170
):
170171
# 0. Default height and width to unet
171-
height = height or self.unet.config.sample_size * 8
172-
width = width or self.unet.config.sample_size * 8
172+
height = height or self.unet.config.sample_size * self.vae_scale_factor
173+
width = width or self.unet.config.sample_size * self.vae_scale_factor
173174

174175
if height % 8 != 0 or width % 8 != 0:
175176
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -192,7 +193,12 @@ def _generate(
192193
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
193194
context = jnp.concatenate([uncond_embeddings, text_embeddings])
194195

195-
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
196+
latents_shape = (
197+
batch_size,
198+
self.unet.in_channels,
199+
height // self.vae_scale_factor,
200+
width // self.vae_scale_factor,
201+
)
196202
if latents is None:
197203
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
198204
else:
@@ -269,9 +275,9 @@ def __call__(
269275
Args:
270276
prompt (`str` or `List[str]`):
271277
The prompt or prompts to guide the image generation.
272-
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
278+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
273279
The height in pixels of the generated image.
274-
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
280+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
275281
The width in pixels of the generated image.
276282
num_inference_steps (`int`, *optional*, defaults to 50):
277283
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -307,8 +313,8 @@ def __call__(
307313
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
308314
"""
309315
# 0. Default height and width to unet
310-
height = height or self.unet.config.sample_size * 8
311-
width = width or self.unet.config.sample_size * 8
316+
height = height or self.unet.config.sample_size * self.vae_scale_factor
317+
width = width or self.unet.config.sample_size * self.vae_scale_factor
312318

313319
if jit:
314320
images = _p_generate(

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108
safety_checker=safety_checker,
109109
feature_extractor=feature_extractor,
110110
)
111+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
111112
self.register_to_config(requires_safety_checker=requires_safety_checker)
112113

113114
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -206,8 +207,8 @@ def __call__(
206207
**kwargs,
207208
):
208209
# 0. Default height and width to unet
209-
height = height or self.unet.config.sample_size * 8
210-
width = width or self.unet.config.sample_size * 8
210+
height = height or self.unet.config.sample_size * self.vae_scale_factor
211+
width = width or self.unet.config.sample_size * self.vae_scale_factor
211212

212213
if isinstance(prompt, str):
213214
batch_size = 1
@@ -241,7 +242,12 @@ def __call__(
241242

242243
# get the initial random noise unless the user supplied it
243244
latents_dtype = text_embeddings.dtype
244-
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
245+
latents_shape = (
246+
batch_size * num_images_per_prompt,
247+
4,
248+
height // self.vae_scale_factor,
249+
width // self.vae_scale_factor,
250+
)
245251
if latents is None:
246252
latents = generator.randn(*latents_shape).astype(latents_dtype)
247253
elif latents.shape != latents_shape:

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def __init__(
158158
safety_checker=safety_checker,
159159
feature_extractor=feature_extractor,
160160
)
161+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
161162
self.register_to_config(requires_safety_checker=requires_safety_checker)
162163

163164
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@@ -273,9 +274,9 @@ def __call__(
273274
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
274275
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
275276
instead of 3, so the expected shape would be `(B, H, W, 1)`.
276-
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
277+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
277278
The height in pixels of the generated image.
278-
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
279+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
279280
The width in pixels of the generated image.
280281
num_inference_steps (`int`, *optional*, defaults to 50):
281282
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -321,8 +322,8 @@ def __call__(
321322
(nsfw) content, according to the `safety_checker`.
322323
"""
323324
# 0. Default height and width to unet
324-
height = height or self.unet.config.sample_size * 8
325-
width = width or self.unet.config.sample_size * 8
325+
height = height or self.unet.config.sample_size * self.vae_scale_factor
326+
width = width or self.unet.config.sample_size * self.vae_scale_factor
326327

327328
if isinstance(prompt, str):
328329
batch_size = 1
@@ -358,7 +359,12 @@ def __call__(
358359
)
359360

360361
num_channels_latents = NUM_LATENT_CHANNELS
361-
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
362+
latents_shape = (
363+
batch_size * num_images_per_prompt,
364+
num_channels_latents,
365+
height // self.vae_scale_factor,
366+
width // self.vae_scale_factor,
367+
)
362368
latents_dtype = text_embeddings.dtype
363369
if latents is None:
364370
latents = generator.randn(*latents_shape).astype(latents_dtype)

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ def preprocess(image):
2727
return 2.0 * image - 1.0
2828

2929

30-
def preprocess_mask(mask):
30+
def preprocess_mask(mask, scale_factor=8):
3131
mask = mask.convert("L")
3232
w, h = mask.size
3333
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
34-
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
34+
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
3535
mask = np.array(mask).astype(np.float32) / 255.0
3636
mask = np.tile(mask, (4, 1, 1))
3737
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -143,6 +143,7 @@ def __init__(
143143
safety_checker=safety_checker,
144144
feature_extractor=feature_extractor,
145145
)
146+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
146147
self.register_to_config(requires_safety_checker=requires_safety_checker)
147148

148149
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@@ -349,7 +350,7 @@ def __call__(
349350

350351
# preprocess mask
351352
if not isinstance(mask_image, np.ndarray):
352-
mask_image = preprocess_mask(mask_image)
353+
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
353354
mask_image = mask_image.astype(latents_dtype)
354355
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)
355356

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(
140140
safety_checker=safety_checker,
141141
feature_extractor=feature_extractor,
142142
)
143+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
143144
self.register_to_config(requires_safety_checker=requires_safety_checker)
144145

145146
def enable_xformers_memory_efficient_attention(self):
@@ -378,7 +379,7 @@ def check_inputs(self, prompt, height, width, callback_steps):
378379
)
379380

380381
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
381-
shape = (batch_size, num_channels_latents, height // 8, width // 8)
382+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
382383
if latents is None:
383384
if device.type == "mps":
384385
# randn does not work reproducibly on mps
@@ -419,9 +420,9 @@ def __call__(
419420
Args:
420421
prompt (`str` or `List[str]`):
421422
The prompt or prompts to guide the image generation.
422-
height (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
423+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
423424
The height in pixels of the generated image.
424-
width (`int`, *optional*, defaults to self.unet.config.sample_size * 8):
425+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
425426
The width in pixels of the generated image.
426427
num_inference_steps (`int`, *optional*, defaults to 50):
427428
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -468,8 +469,8 @@ def __call__(
468469
(nsfw) content, according to the `safety_checker`.
469470
"""
470471
# 0. Default height and width to unet
471-
height = height or self.unet.config.sample_size * 8
472-
width = width or self.unet.config.sample_size * 8
472+
height = height or self.unet.config.sample_size * self.vae_scale_factor
473+
width = width or self.unet.config.sample_size * self.vae_scale_factor
473474

474475
# 1. Check inputs. Raise error if not correct
475476
self.check_inputs(prompt, height, width, callback_steps)

0 commit comments

Comments
 (0)