Commit 0be2351

Merge branch 'resolution-checker' of https://github.com/blessedcoolant/stable-diffusion into main
2 parents a14fd69 + 1480ef8

File tree

ldm/dream/image_util.py
ldm/simplet2i.py

2 files changed (+61 -29 lines)


ldm/dream/image_util.py

Lines changed: 2 additions & 0 deletions

@@ -50,6 +50,8 @@ def resize(self,width=None,height=None) -> Image:
         new_image = Image.new('RGB',(width,height))
         new_image.paste(resized_image,((width-rw)//2,(height-rh)//2))
 
+        print(f'>> Resized image size to {width}x{height}')
+
         return new_image
 
 def make_grid(image_list, rows=None, cols=None):
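
For context, a minimal standalone sketch (not part of the commit; the canvas and stand-in sizes are assumptions) of the centre-paste step that resize() performs just before the new '>> Resized image size' message is printed:

# Minimal sketch, assuming Pillow is installed; the sizes below are illustrative only.
from PIL import Image

width, height = 512, 512                        # target canvas size
resized_image = Image.new('RGB', (384, 512))    # stand-in for the rescaled init image
rw, rh = resized_image.size

new_image = Image.new('RGB', (width, height))
# paste the rescaled image centred on the blank canvas, as resize() does
new_image.paste(resized_image, ((width - rw) // 2, (height - rh) // 2))
print(f'>> Resized image size to {width}x{height}')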

ldm/simplet2i.py

Lines changed: 59 additions & 29 deletions

@@ -266,16 +266,9 @@ def process_image(image,seed):
         assert (
             0.0 <= strength <= 1.0
         ), 'can only work with strength in [0.0, 1.0]'
-        w, h = map(
-            lambda x: x - x % 64, (width, height)
-        )  # resize to integer multiple of 64
 
-        if h != height or w != width:
-            print(
-                f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
-            )
-            height = h
-            width = w
+        if not(width == self.width and height == self.height):
+            width, height, _ = self._resolution_check(width, height, log=True)
 
         scope = autocast if self.precision == 'autocast' else nullcontext
 
@@ -353,8 +346,11 @@ def process_image(image,seed):
                             f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
                         )
                     if image_callback is not None:
-                        image_callback(image, seed, upscaled=True)
-                    else: # no callback passed, so we simply replace old image with rescaled one
+                        if save_original:
+                            image_callback(image, seed)
+                        else:
+                            image_callback(image, seed, upscaled=True)
+                    else:  # no callback passed, so we simply replace old image with rescaled one
                         result[0] = image
 
         except KeyboardInterrupt:
@@ -436,7 +432,7 @@ def _img2img(
         width,
         height,
         strength,
-        callback, # Currently not implemented for img2img
+        callback,  # Currently not implemented for img2img
     ):
         """
         An infinite iterator of images from the prompt and the initial image
@@ -445,13 +441,13 @@ def _img2img(
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name != 'ddim':
             print(
-                f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
+                f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
             )
             sampler = DDIMSampler(self.model, device=self.device)
         else:
             sampler = self.sampler
 
-        init_image = self._load_img(init_img,width,height).to(self.device)
+        init_image = self._load_img(init_img, width, height).to(self.device)
         with precision_scope(self.device.type):
             init_latent = self.model.get_first_stage_encoding(
                 self.model.encode_first_stage(init_image)
@@ -514,7 +510,8 @@ def _sample_to_image(self, samples):
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         if len(x_samples) != 1:
-            raise Exception(f'expected to get a single image, but got {len(x_samples)}')
+            raise Exception(
+                f'expected to get a single image, but got {len(x_samples)}')
         x_sample = 255.0 * rearrange(
             x_samples[0].cpu().numpy(), 'c h w -> h w c'
         )
@@ -545,8 +542,9 @@ def load_model(self):
                 self.model.cond_stage_model.device = self.device
             except AttributeError:
                 import traceback
-                print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
-                print(traceback.format_exc(),file=sys.stderr)
+                print(
+                    'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
                 raise SystemExit
 
             self._set_sampler()
@@ -606,10 +604,26 @@ def _load_img(self, path, width, height):
         print(f'image path = {path}, cwd = {os.getcwd()}')
         with Image.open(path) as img:
             image = img.convert('RGB')
-        print(f'loaded input image of size {image.width}x{image.height} from {path}')
+        print(
+            f'loaded input image of size {image.width}x{image.height} from {path}')
 
-        image = InitImageResizer(image).resize(width,height)
-        print(f'resized input image to size {image.width}x{image.height}')
+        from ldm.dream.image_util import InitImageResizer
+        if width == self.width and height == self.height:
+            new_image_width, new_image_height, resize_needed = self._resolution_check(
+                image.width, image.height)
+        else:
+            if height == self.height:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    width, image.height)
+            if width == self.width:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
                    image.width, height)
+            else:
+                image = InitImageResizer(image).resize(width, height)
+                resize_needed=False
+        if resize_needed:
+            image = InitImageResizer(image).resize(
+                new_image_width, new_image_height)
 
         image = np.array(image).astype(np.float32) / 255.0
         image = image[None].transpose(0, 3, 1, 2)
@@ -633,7 +647,7 @@ def _split_weighted_subprompts(text):
                 prompt = text[:idx]
                 remaining -= idx
                 # remove from main text
-                text = text[idx + 1 :]
+                text = text[idx + 1:]
                 # find value for weight
                 if ' ' in text:
                     idx = text.index(' ')  # first occurence
@@ -651,7 +665,7 @@ def _split_weighted_subprompts(text):
                     weight = 1.0
                 # remove from main text
                 remaining -= idx
-                text = text[idx + 1 :]
+                text = text[idx + 1:]
                 # append the sub-prompt and its weight
                 prompts.append(prompt)
                 weights.append(weight)
@@ -662,9 +676,9 @@ def _split_weighted_subprompts(text):
             weights.append(1.0)
             remaining = 0
         return prompts, weights
-
-    # shows how the prompt is tokenized
-    # usually tokens have '</w>' to indicate end-of-word,
+
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     def _log_tokenization(self, text):
         if not self.log_tokenization:
@@ -674,15 +688,31 @@ def _log_tokenization(self, text):
         discarded = ""
         usedTokens = 0
         totalTokens = len(tokens)
-        for i in range(0,totalTokens):
-            token = tokens[i].replace('</w>',' ')
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
             # alternate color
             s = (usedTokens % 6) + 1
             if i < self.model.cond_stage_model.max_length:
                 tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
                 usedTokens += 1
-            else: # over max token length
+            else:  # over max token length
                 discarded = discarded + f"\x1b[0;3{s};40m{token}"
         print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
         if discarded != "":
-            print(f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
+            print(
+                f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
+
+    def _resolution_check(self, width, height, log=False):
+        resize_needed = False
+        w, h = map(
+            lambda x: x - x % 64, (width, height)
+        )  # resize to integer multiple of 64
+        if h != height or w != width:
+            if log:
+                print(
+                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
+                )
+            height = h
+            width = w
+            resize_needed = True
+        return width, height, resize_needed
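
For context, a standalone sketch (assumed usage, not part of the commit) of the rounding rule the new _resolution_check() helper applies: each dimension is truncated to the nearest lower multiple of 64, and a flag reports whether a resize is needed:

# Minimal sketch of the multiple-of-64 truncation used by _resolution_check();
# this free function and the example calls are illustrative, not from the commit.
def resolution_check(width, height, log=False):
    resize_needed = False
    w, h = map(lambda x: x - x % 64, (width, height))  # truncate to multiple of 64
    if h != height or w != width:
        if log:
            print(f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}')
        width, height = w, h
        resize_needed = True
    return width, height, resize_needed

print(resolution_check(520, 730))   # -> (512, 704, True)
print(resolution_check(512, 768))   # -> (512, 768, False)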
