@@ -266,16 +266,9 @@ def process_image(image,seed):
         assert (
             0.0 <= strength <= 1.0
         ), 'can only work with strength in [0.0, 1.0]'
-        w, h = map(
-            lambda x: x - x % 64, (width, height)
-        )  # resize to integer multiple of 64
 
-        if h != height or w != width:
-            print(
-                f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
-            )
-            height = h
-            width = w
+        if not (width == self.width and height == self.height):
+            width, height, _ = self._resolution_check(width, height, log=True)
 
         scope = autocast if self.precision == 'autocast' else nullcontext
 
@@ -353,8 +346,11 @@ def process_image(image,seed):
                             f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
                         )
                     if image_callback is not None:
-                        image_callback(image, seed, upscaled=True)
-                    else: # no callback passed, so we simply replace old image with rescaled one
+                        if save_original:
+                            image_callback(image, seed)
+                        else:
+                            image_callback(image, seed, upscaled=True)
+                    else:  # no callback passed, so we simply replace old image with rescaled one
                         result[0] = image
 
         except KeyboardInterrupt:
@@ -436,7 +432,7 @@ def _img2img(
         width,
         height,
         strength,
-        callback, # Currently not implemented for img2img
+        callback,  # Currently not implemented for img2img
     ):
         """
         An infinite iterator of images from the prompt and the initial image
@@ -445,13 +441,13 @@ def _img2img(
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name != 'ddim':
             print(
-                f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
+                f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
             )
             sampler = DDIMSampler(self.model, device=self.device)
         else:
             sampler = self.sampler
 
-        init_image = self._load_img(init_img,width,height).to(self.device)
+        init_image = self._load_img(init_img, width, height).to(self.device)
         with precision_scope(self.device.type):
             init_latent = self.model.get_first_stage_encoding(
                 self.model.encode_first_stage(init_image)
@@ -514,7 +510,8 @@ def _sample_to_image(self, samples):
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         if len(x_samples) != 1:
-            raise Exception(f'expected to get a single image, but got {len(x_samples)}')
+            raise Exception(
+                f'expected to get a single image, but got {len(x_samples)}')
         x_sample = 255.0 * rearrange(
             x_samples[0].cpu().numpy(), 'c h w -> h w c'
         )
@@ -545,8 +542,9 @@ def load_model(self):
                 self.model.cond_stage_model.device = self.device
             except AttributeError:
                 import traceback
-                print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
-                print(traceback.format_exc(),file=sys.stderr)
+                print(
+                    'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
                 raise SystemExit
 
         self._set_sampler()
@@ -606,10 +604,26 @@ def _load_img(self, path, width, height):
         print(f'image path = {path}, cwd = {os.getcwd()}')
         with Image.open(path) as img:
             image = img.convert('RGB')
-            print(f'loaded input image of size {image.width}x{image.height} from {path}')
+            print(
+                f'loaded input image of size {image.width}x{image.height} from {path}')
 
-        image = InitImageResizer(image).resize(width,height)
-        print(f'resized input image to size {image.width}x{image.height}')
+        from ldm.dream.image_util import InitImageResizer
+        if width == self.width and height == self.height:
+            new_image_width, new_image_height, resize_needed = self._resolution_check(
+                image.width, image.height)
+        else:
+            if height == self.height:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    width, image.height)
+            if width == self.width:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    image.width, height)
+            else:
+                image = InitImageResizer(image).resize(width, height)
+                resize_needed = False
+        if resize_needed:
+            image = InitImageResizer(image).resize(
+                new_image_width, new_image_height)
 
         image = np.array(image).astype(np.float32) / 255.0
         image = image[None].transpose(0, 3, 1, 2)
@@ -633,7 +647,7 @@ def _split_weighted_subprompts(text):
                 prompt = text[:idx]
                 remaining -= idx
                 # remove from main text
-                text = text[idx+1:]
+                text = text[idx + 1:]
                 # find value for weight
                 if ' ' in text:
                     idx = text.index(' ')  # first occurence
@@ -651,7 +665,7 @@ def _split_weighted_subprompts(text):
                     weight = 1.0
                 # remove from main text
                 remaining -= idx
-                text = text[idx+1:]
+                text = text[idx + 1:]
                 # append the sub-prompt and its weight
                 prompts.append(prompt)
                 weights.append(weight)
@@ -662,9 +676,9 @@ def _split_weighted_subprompts(text):
                     weights.append(1.0)
                 remaining = 0
         return prompts, weights
-
-    # shows how the prompt is tokenized
-    # usually tokens have '</w>' to indicate end-of-word,
+
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     def _log_tokenization(self, text):
         if not self.log_tokenization:
@@ -674,15 +688,31 @@ def _log_tokenization(self, text):
         discarded = ""
         usedTokens = 0
         totalTokens = len(tokens)
-        for i in range(0,totalTokens):
-            token = tokens[i].replace('</w>',' ')
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
             # alternate color
             s = (usedTokens % 6) + 1
             if i < self.model.cond_stage_model.max_length:
                 tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
                 usedTokens += 1
-            else: # over max token length
+            else:  # over max token length
                 discarded = discarded + f"\x1b[0;3{s};40m{token}"
         print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
         if discarded != "":
-            print(f"Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m")
+            print(
+                f"Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m")
+
+    def _resolution_check(self, width, height, log=False):
+        resize_needed = False
+        w, h = map(
+            lambda x: x - x % 64, (width, height)
+        )  # resize to integer multiple of 64
+        if h != height or w != width:
+            if log:
+                print(
+                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
+                )
+            height = h
+            width = w
+            resize_needed = True
+        return width, height, resize_needed
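
For reference, a minimal standalone sketch (not part of the commit) of the rounding rule the new _resolution_check helper applies: each dimension is rounded down to the nearest multiple of 64, and a flag reports whether a resize is needed. The free function name below is hypothetical and only mirrors the method's return shape.

    def resolution_check(width, height):
        # Round each dimension down to the nearest multiple of 64.
        w, h = (x - x % 64 for x in (width, height))
        # Report whether the caller still needs to resize the image.
        return w, h, (w, h) != (width, height)

    print(resolution_check(512, 768))  # (512, 768, False) - already multiples of 64
    print(resolution_check(513, 770))  # (512, 768, True)  - rounded down; resize needed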