 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.ksampler import KSampler
 from ldm.dream.pngwriter import PngWriter
-from ldm.dream.image_util import InitImageResizer
 
 """Simplified text to image API for stable diffusion/latent diffusion
 
@@ -261,16 +260,9 @@ def process_image(image,seed):
         assert (
             0.0 <= strength <= 1.0
         ), 'can only work with strength in [0.0, 1.0]'
-        w, h = map(
-            lambda x: x - x % 64, (width, height)
-        )  # resize to integer multiple of 64
 
-        if h != height or w != width:
-            print(
-                f'Height and width must be multiples of 64. Resizing to {h}x{w}.'
-            )
-            height = h
-            width = w
+        if not (width == self.width and height == self.height):
+            width, height, _ = self._resolution_check(width, height, log=True)
 
         scope = autocast if self.precision == 'autocast' else nullcontext
 
@@ -352,7 +344,7 @@ def process_image(image,seed):
                                 image_callback(image, seed)
                             else:
                                 image_callback(image, seed, upscaled=True)
-                        else: # no callback passed, so we simply replace old image with rescaled one
+                        else:  # no callback passed, so we simply replace old image with rescaled one
                             result[0] = image
 
                     except KeyboardInterrupt:
@@ -434,7 +426,7 @@ def _img2img(
         width,
         height,
         strength,
-        callback,    # Currently not implemented for img2img
+        callback,  # Currently not implemented for img2img
     ):
         """
         An infinite iterator of images from the prompt and the initial image
@@ -443,13 +435,13 @@ def _img2img(
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name != 'ddim':
             print(
-                f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler"
+                f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler"
             )
             sampler = DDIMSampler(self.model, device=self.device)
         else:
             sampler = self.sampler
 
-        init_image = self._load_img(init_img,width,height).to(self.device)
+        init_image = self._load_img(init_img, width, height).to(self.device)
         with precision_scope(self.device.type):
             init_latent = self.model.get_first_stage_encoding(
                 self.model.encode_first_stage(init_image)
@@ -512,7 +504,8 @@ def _sample_to_image(self, samples):
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         if len(x_samples) != 1:
-            raise Exception(f'expected to get a single image, but got {len(x_samples)}')
+            raise Exception(
+                f'expected to get a single image, but got {len(x_samples)}')
         x_sample = 255.0 * rearrange(
             x_samples[0].cpu().numpy(), 'c h w -> h w c'
         )
@@ -547,8 +540,9 @@ def load_model(self):
                 self.model.cond_stage_model.device = self.device
             except AttributeError:
                 import traceback
-                print('Error loading model. Only the CUDA backend is supported',file=sys.stderr)
-                print(traceback.format_exc(),file=sys.stderr)
+                print(
+                    'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
                 raise SystemExit
 
             self._set_sampler()
@@ -608,10 +602,26 @@ def _load_img(self, path, width, height):
         print(f'image path = {path}, cwd = {os.getcwd()}')
         with Image.open(path) as img:
             image = img.convert('RGB')
-            print(f'loaded input image of size {image.width}x{image.height} from {path}')
+            print(
+                f'loaded input image of size {image.width}x{image.height} from {path}')
 
-        image = InitImageResizer(image).resize(width,height)
-        print(f'resized input image to size {image.width}x{image.height}')
+        from ldm.dream.image_util import InitImageResizer
+        if width == self.width and height == self.height:
+            new_image_width, new_image_height, resize_needed = self._resolution_check(
+                image.width, image.height)
+        else:
+            if height == self.height:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    width, image.height)
+            if width == self.width:
+                new_image_width, new_image_height, resize_needed = self._resolution_check(
+                    image.width, height)
+            else:
+                image = InitImageResizer(image).resize(width, height)
+                resize_needed = False
+        if resize_needed:
+            image = InitImageResizer(image).resize(
                new_image_width, new_image_height)
 
         image = np.array(image).astype(np.float32) / 255.0
         image = image[None].transpose(0, 3, 1, 2)
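A worked example of the new _load_img flow above (illustrative only, not part of the commit): when the caller leaves width and height at the model's defaults, it is the init image's own dimensions that get validated, so an image whose sides are not multiples of 64 is rounded down and resized before encoding. Assuming the usual 512x512 defaults and a 700x500 init image:

    # sketch, not part of the diff: width=512, height=512 (model defaults), init image is 700x500
    w, h = map(lambda x: x - x % 64, (700, 500))   # -> (640, 448)
    # _resolution_check(700, 500) therefore returns (640, 448, True), and
    # InitImageResizer(image).resize(640, 448) runs before the np.array conversion shown above.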
@@ -635,7 +645,7 @@ def _split_weighted_subprompts(text):
                 prompt = text[:idx]
                 remaining -= idx
                 # remove from main text
-                text = text[idx + 1 :]
+                text = text[idx + 1:]
                 # find value for weight
                 if ' ' in text:
                     idx = text.index(' ')  # first occurence
@@ -653,7 +663,7 @@ def _split_weighted_subprompts(text):
                     weight = 1.0
                 # remove from main text
                 remaining -= idx
-                text = text[idx + 1 :]
+                text = text[idx + 1:]
                 # append the sub-prompt and its weight
                 prompts.append(prompt)
                 weights.append(weight)
@@ -664,9 +674,9 @@ def _split_weighted_subprompts(text):
                 weights.append(1.0)
                 remaining = 0
         return prompts, weights
-
-    # shows how the prompt is tokenized
-    # usually tokens have '</w>' to indicate end-of-word,
+
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     def _log_tokenization(self, text):
         if not self.log_tokenization:
@@ -676,15 +686,31 @@ def _log_tokenization(self, text):
         discarded = ""
         usedTokens = 0
         totalTokens = len(tokens)
-        for i in range(0,totalTokens):
-            token = tokens[i].replace('</w>',' ')
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
             # alternate color
             s = (usedTokens % 6) + 1
             if i < self.model.cond_stage_model.max_length:
                 tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
                 usedTokens += 1
-            else: # over max token length
+            else:  # over max token length
                 discarded = discarded + f"\x1b[0;3{s};40m{token}"
         print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
         if discarded != "":
-            print(f"Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m")
+            print(
+                f"Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m")
+
+    def _resolution_check(self, width, height, log=False):
+        resize_needed = False
+        w, h = map(
+            lambda x: x - x % 64, (width, height)
+        )  # resize to integer multiple of 64
+        if h != height or w != width:
+            if log:
+                print(
+                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
+                )
+            height = h
+            width = w
+            resize_needed = True
+        return width, height, resize_needed
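For reference, a quick sketch of how the new helper behaves (hypothetical calls, not part of the commit):

    # dimensions already multiples of 64: returned unchanged, no resize flagged
    self._resolution_check(512, 768)             # -> (512, 768, False)
    # otherwise both values are rounded down to multiples of 64 and resize_needed is set
    self._resolution_check(1000, 600, log=True)  # -> (960, 576, True), and prints the '>>' notice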