Commit 1480ef8

Add Resolution Checker
1 parent 1714816 commit 1480ef8

2 files changed: +57, -29 lines


ldm/dream/image_util.py
2 additions & 0 deletions

@@ -49,6 +49,8 @@ def resize(self,width=None,height=None) -> Image:
         new_image = Image.new('RGB',(width,height))
         new_image.paste(resized_image,((width-rw)//2,(height-rh)//2))
 
+        print(f'>> Resized image size to {width}x{height}')
+
         return new_image
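
The two added lines in image_util.py only log the final canvas size; the surrounding context shows InitImageResizer pasting the resized image centered on a fresh RGB canvas before returning it. A minimal standalone sketch of that center-paste pattern (names are illustrative, not the project's API):

from PIL import Image

def center_on_canvas(resized_image: Image.Image, width: int, height: int) -> Image.Image:
    # rw/rh mirror the variables visible in the diff context above
    rw, rh = resized_image.width, resized_image.height
    new_image = Image.new('RGB', (width, height))
    # offset by half the leftover space in each dimension so the paste lands centered
    new_image.paste(resized_image, ((width - rw) // 2, (height - rh) // 2))
    print(f'>> Resized image size to {width}x{height}')
    return new_image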

ldm/simplet2i.py
55 additions & 29 deletions
@@ -27,7 +27,6 @@
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.ksampler import KSampler
 from ldm.dream.pngwriter import PngWriter
-from ldm.dream.image_util import InitImageResizer
 
 """Simplified text to image API for stable diffusion/latent diffusion

@@ -261,16 +260,9 @@ def process_image(image,seed):
         assert (
             0.0 <= strength <= 1.0
         ), 'can only work with strength in [0.0, 1.0]'
-        w, h = map(
-            lambda x: x - x % 64, (width, height)
-        )  # resize to integer multiple of 64
 
-        if h != height or w != width:
-            print(
-                f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
-            )
-            height = h
-            width = w
+        if not(width == self.width and height == self.height):
+            width, height, _ = self._resolution_check(width, height, log=True)
 
         scope = autocast if self.precision == 'autocast' else nullcontext
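
In the text-to-image path the inline rounding is now delegated to the new _resolution_check helper (defined at the end of this file's diff); only the corrected dimensions are kept and the resize flag is discarded. A hypothetical standalone illustration of that guard, assuming instance defaults of 512x512 for self.width/self.height:

default_width, default_height = 512, 512   # assumed stand-ins for self.width / self.height
width, height = 520, 704                   # user-requested generation size

if not (width == default_width and height == default_height):
    # what _resolution_check does internally: round down to multiples of 64
    width, height = (x - x % 64 for x in (width, height))

print(width, height)  # 512 704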

@@ -352,7 +344,7 @@ def process_image(image,seed):
                     image_callback(image, seed)
                 else:
                     image_callback(image, seed, upscaled=True)
-            else: # no callback passed, so we simply replace old image with rescaled one
+            else:  # no callback passed, so we simply replace old image with rescaled one
                 result[0] = image
 
         except KeyboardInterrupt:

@@ -434,7 +426,7 @@ def _img2img(
         width,
         height,
         strength,
-        callback, # Currently not implemented for img2img
+        callback,  # Currently not implemented for img2img
     ):
         """
         An infinite iterator of images from the prompt and the initial image

@@ -443,13 +435,13 @@ def _img2img(
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name != 'ddim':
             print(
-                f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
+                f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
             )
             sampler = DDIMSampler(self.model, device=self.device)
         else:
             sampler = self.sampler
 
-        init_image = self._load_img(init_img,width,height).to(self.device)
+        init_image = self._load_img(init_img, width, height).to(self.device)
         with precision_scope(self.device.type):
             init_latent = self.model.get_first_stage_encoding(
                 self.model.encode_first_stage(init_image)

@@ -512,7 +504,8 @@ def _sample_to_image(self, samples):
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         if len(x_samples) != 1:
-            raise Exception(f'expected to get a single image, but got {len(x_samples)}')
+            raise Exception(
+                f'expected to get a single image, but got {len(x_samples)}')
         x_sample = 255.0 * rearrange(
             x_samples[0].cpu().numpy(), 'c h w -> h w c'
         )

@@ -547,8 +540,9 @@ def load_model(self):
             self.model.cond_stage_model.device = self.device
         except AttributeError:
             import traceback
-            print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
-            print(traceback.format_exc(),file=sys.stderr)
+            print(
+                'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
             raise SystemExit
 
         self._set_sampler()

@@ -608,10 +602,26 @@ def _load_img(self, path, width, height):
         print(f'image path = {path}, cwd = {os.getcwd()}')
         with Image.open(path) as img:
             image = img.convert('RGB')
-            print(f'loaded input image of size {image.width}x{image.height} from {path}')
+            print(
+                f'loaded input image of size {image.width}x{image.height} from {path}')
 
-        image = InitImageResizer(image).resize(width,height)
-        print(f'resized input image to size {image.width}x{image.height}')
+        from ldm.dream.image_util import InitImageResizer
+        if width == self.width and height == self.height:
+            new_image_width, new_image_height, resize_needed = self._resolution_check(
+                image.width, image.height)
+        else:
+            if height == self.height:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    width, image.height)
+            if width == self.width:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    image.width, height)
+            else:
+                image = InitImageResizer(image).resize(width, height)
+                resize_needed=False
+        if resize_needed:
+            image = InitImageResizer(image).resize(
+                new_image_width, new_image_height)
 
         image = np.array(image).astype(np.float32) / 255.0
         image = image[None].transpose(0, 3, 1, 2)
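
The reworked _load_img decides which dimensions need the multiple-of-64 check based on which of them the caller overrode, and only resizes when the check says so. A rough, simplified sketch of that decision outside the class (default_w/default_h stand in for self.width/self.height; this is an illustration, not a drop-in replacement):

from PIL import Image

def pick_target_size(image: Image.Image, width: int, height: int, default_w: int, default_h: int):
    """Return (target_width, target_height, resize_needed)."""
    def round64(w, h):
        w2, h2 = w - w % 64, h - h % 64
        return w2, h2, (w2, h2) != (w, h)

    if width == default_w and height == default_h:
        # nothing customized: validate the init image's own resolution
        return round64(image.width, image.height)
    if width == default_w:
        # only the height was customized: keep the image's width
        return round64(image.width, height)
    if height == default_h:
        # only the width was customized: keep the image's height
        return round64(width, image.height)
    # both customized: the diff resizes directly to the requested size
    return width, height, True
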
@@ -635,7 +645,7 @@ def _split_weighted_subprompts(text):
             prompt = text[:idx]
             remaining -= idx
             # remove from main text
-            text = text[idx + 1 :]
+            text = text[idx + 1:]
             # find value for weight
             if ' ' in text:
                 idx = text.index(' ')  # first occurence

@@ -653,7 +663,7 @@ def _split_weighted_subprompts(text):
                 weight = 1.0
             # remove from main text
             remaining -= idx
-            text = text[idx + 1 :]
+            text = text[idx + 1:]
             # append the sub-prompt and its weight
             prompts.append(prompt)
             weights.append(weight)

@@ -664,9 +674,9 @@ def _split_weighted_subprompts(text):
                 weights.append(1.0)
             remaining = 0
         return prompts, weights
-
-    # shows how the prompt is tokenized
-    # usually tokens have '</w>' to indicate end-of-word,
+
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     def _log_tokenization(self, text):
         if not self.log_tokenization:

@@ -676,15 +686,31 @@ def _log_tokenization(self, text):
         discarded = ""
         usedTokens = 0
         totalTokens = len(tokens)
-        for i in range(0,totalTokens):
-            token = tokens[i].replace('</w>',' ')
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
             # alternate color
             s = (usedTokens % 6) + 1
             if i < self.model.cond_stage_model.max_length:
                 tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
                 usedTokens += 1
-            else: # over max token length
+            else:  # over max token length
                 discarded = discarded + f"\x1b[0;3{s};40m{token}"
         print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
         if discarded != "":
-            print(f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
+            print(
+                f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")
+
+    def _resolution_check(self, width, height, log=False):
+        resize_needed = False
+        w, h = map(
+            lambda x: x - x % 64, (width, height)
+        )  # resize to integer multiple of 64
+        if h != height or w != width:
+            if log:
+                print(
+                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
+                )
+            height = h
+            width = w
+            resize_needed = True
+        return width, height, resize_needed
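
As a quick sanity check of the helper's return contract, a standalone restatement (illustrative only; the real method also prints a notice when log=True):

def resolution_check(width, height):
    # round each dimension down to the nearest multiple of 64
    w, h = (x - x % 64 for x in (width, height))
    return w, h, (w, h) != (width, height)

print(resolution_check(512, 512))   # (512, 512, False) -> already multiples of 64
print(resolution_check(1000, 600))  # (960, 576, True)  -> auto-resize target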
