@@ -483,22 +483,19 @@ def _get_uc_and_c(self, prompt, skip_normalize):
 
         uc = self.model.get_learned_conditioning([''])
 
-        # weighted sub-prompts
-        subprompts, weights = T2I._split_weighted_subprompts(prompt)
-        if len(subprompts) > 1:
+        # get weighted sub-prompts
+        weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)
+
+        if len(weighted_subprompts) > 1:
             # i dont know if this is correct.. but it works
             c = torch.zeros_like(uc)
-            # get total weight for normalizing
-            totalWeight = sum(weights)
             # normalize each "sub prompt" and add it
-            for i in range(0, len(subprompts)):
-                weight = weights[i]
-                if not skip_normalize:
-                    weight = weight / totalWeight
-                self._log_tokenization(subprompts[i])
+            for i in range(0, len(weighted_subprompts)):
+                subprompt, weight = weighted_subprompts[i]
+                self._log_tokenization(subprompt)
                 c = torch.add(
                     c,
-                    self.model.get_learned_conditioning([subprompts[i]]),
+                    self.model.get_learned_conditioning([subprompt]),
                     alpha=weight,
                 )
         else:   # just standard 1 prompt
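
For intuition, here is a minimal standalone sketch of what the loop above computes: `torch.add(input, other, alpha=w)` evaluates `input + w * other`, so `c` accumulates a weighted sum of the per-sub-prompt conditioning tensors. The random tensors and the `(1, 77, 768)` shape are assumptions standing in for what `get_learned_conditioning` actually returns, not taken from the diff.

```python
import torch

# Hypothetical stand-ins for model.get_learned_conditioning() outputs;
# the (1, 77, 768) shape is an assumption for illustration only.
weighted_subprompts = [('a red house', 0.75), ('in a snowy forest', 0.25)]
embeddings = {text: torch.randn(1, 77, 768) for text, _ in weighted_subprompts}

c = torch.zeros(1, 77, 768)
for subprompt, weight in weighted_subprompts:
    # torch.add(input, other, alpha=w) computes input + w * other
    c = torch.add(c, embeddings[subprompt], alpha=weight)

# The loop is equivalent to a plain weighted sum of the embeddings.
expected = sum(w * embeddings[t] for t, w in weighted_subprompts)
assert torch.allclose(c, expected)
```
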
@@ -630,55 +627,39 @@ def _load_img(self, path, width, height):
         image = torch.from_numpy(image)
         return 2.0 * image - 1.0
 
-    def _split_weighted_subprompts(text):
+    def _split_weighted_subprompts(text, skip_normalize=False):
         """
         grabs all text up to the first occurrence of ':'
         uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
         if ':' has no value defined, defaults to 1.0
         repeats until no text remaining
         """
-        remaining = len(text)
-        prompts = []
-        weights = []
-        while remaining > 0:
-            if ':' in text:
-                idx = text.index(':')  # first occurrence from start
-                # grab up to index as sub-prompt
-                prompt = text[:idx]
-                remaining -= idx
-                # remove from main text
-                text = text[idx + 1:]
-                # find value for weight
-                if ' ' in text:
-                    idx = text.index(' ')  # first occurence
-                else:   # no space, read to end
-                    idx = len(text)
-                if idx != 0:
-                    try:
-                        weight = float(text[:idx])
-                    except:  # couldn't treat as float
-                        print(
-                            f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
-                        )
-                        weight = 1.0
-                else:   # no value found
-                    weight = 1.0
-                # remove from main text
-                remaining -= idx
-                text = text[idx + 1:]
-                # append the sub-prompt and its weight
-                prompts.append(prompt)
-                weights.append(weight)
-            else:   # no : found
-                if len(text) > 0:  # there is still text though
-                    # take remainder as weight 1
-                    prompts.append(text)
-                    weights.append(1.0)
-                remaining = 0
-        return prompts, weights
-
-    # shows how the prompt is tokenized
-    # usually tokens have '</w>' to indicate end-of-word,
+        prompt_parser = re.compile("""
+            (?P<prompt>         # capture group for 'prompt'
+            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
+            )                   # end 'prompt'
+            (?:                 # non-capture group
+            :+                  # match one or more ':' characters
+            (?P<weight>         # capture group for 'weight'
+            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
+            )?                  # end weight capture group, make optional
+            \s*                 # strip spaces after weight
+            |                   # OR
+            $                   # else, if no ':' then match end of line
+            )                   # end non-capture group
+            """, re.VERBOSE)
+        parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
+        if skip_normalize:
+            return parsed_prompts
+        weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+        if weight_sum == 0:
+            print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
+            equal_weight = 1 / len(parsed_prompts)
+            return [(x[0], equal_weight) for x in parsed_prompts]
+        return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
+
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     def _log_tokenization(self, text):
         if not self.log_tokenization:
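
To make the new parser's behavior concrete, here is a small standalone sketch using the same regex (written as a raw string, and omitting the zero-weight guard for brevity). The sample prompt string is made up for illustration; it shows both the raw parse, which is what `skip_normalize=True` returns, and the normalized weights:

```python
import re

# Same pattern as in the diff, condensed and as a raw string.
prompt_parser = re.compile(r"""
    (?P<prompt>(?:\\\:|[^:])+)                 # text up to an unescaped ':'
    (?::+(?P<weight>-?\d+(?:\.\d+)?)?\s*|$)    # optional numeric weight, or end of line
    """, re.VERBOSE)

def split_weighted_subprompts(text, skip_normalize=False):
    parsed = [(m.group("prompt").replace("\\:", ":"), float(m.group("weight") or 1))
              for m in re.finditer(prompt_parser, text)]
    if skip_normalize:
        return parsed
    weight_sum = sum(w for _, w in parsed)
    return [(p, w / weight_sum) for p, w in parsed]

print(split_weighted_subprompts("a red house:2 snowy forest:1 oil painting", skip_normalize=True))
# [('a red house', 2.0), ('snowy forest', 1.0), ('oil painting', 1.0)]

print(split_weighted_subprompts("a red house:2 snowy forest:1 oil painting"))
# [('a red house', 0.5), ('snowy forest', 0.25), ('oil painting', 0.25)]
```

A sub-prompt with no trailing `:weight` defaults to 1.0, and an escaped colon (`\:`) is kept as part of the prompt text rather than treated as a weight separator.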