@@ -487,22 +487,19 @@ def _get_uc_and_c(self, prompt, skip_normalize):
 
         uc = self.model.get_learned_conditioning([''])
 
-        # weighted sub-prompts
-        subprompts, weights = T2I._split_weighted_subprompts(prompt)
-        if len(subprompts) > 1:
+        # get weighted sub-prompts
+        weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)
+
+        if len(weighted_subprompts) > 1:
             # i dont know if this is correct.. but it works
             c = torch.zeros_like(uc)
-            # get total weight for normalizing
-            totalWeight = sum(weights)
             # normalize each "sub prompt" and add it
-            for i in range(0, len(subprompts)):
-                weight = weights[i]
-                if not skip_normalize:
-                    weight = weight / totalWeight
-                self._log_tokenization(subprompts[i])
+            for i in range(0, len(weighted_subprompts)):
+                subprompt, weight = weighted_subprompts[i]
+                self._log_tokenization(subprompt)
                 c = torch.add(
                     c,
                     self.model.get_learned_conditioning([subprompt]),
                     alpha=weight,
                 )
         else:  # just standard 1 prompt
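
Note on the hunk above: the composite conditioning c is accumulated as a weighted sum of per-sub-prompt embeddings, since torch.add(c, e, alpha=w) computes c + w * e. Below is a minimal standalone sketch of that accumulation; blend_conditionings is a hypothetical helper standing in for the loop, with toy tensors in place of model.get_learned_conditioning:

import torch

def blend_conditionings(embeddings, weights):
    # weighted sum of conditioning tensors: c = sum_i(w_i * e_i)
    c = torch.zeros_like(embeddings[0])
    for e, w in zip(embeddings, weights):
        c = torch.add(c, e, alpha=w)  # c + w * e
    return c

# toy stand-ins for two sub-prompt embeddings
e1 = torch.ones(2, 3)
e2 = torch.full((2, 3), 2.0)
print(blend_conditionings([e1, e2], [0.75, 0.25]))  # every entry is 0.75*1 + 0.25*2 = 1.25
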
@@ -616,52 +613,36 @@ def _load_img(self, path, width, height):
         image = torch.from_numpy(image)
         return 2.0 * image - 1.0
 
-    def _split_weighted_subprompts(text):
+    def _split_weighted_subprompts(text, skip_normalize=False):
         """
         grabs all text up to the first occurrence of ':'
         uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
         if ':' has no value defined, defaults to 1.0
         repeats until no text remaining
         """
-        remaining = len(text)
-        prompts = []
-        weights = []
-        while remaining > 0:
-            if ':' in text:
-                idx = text.index(':')  # first occurrence from start
-                # grab up to index as sub-prompt
-                prompt = text[:idx]
-                remaining -= idx
-                # remove from main text
-                text = text[idx + 1:]
-                # find value for weight
-                if ' ' in text:
-                    idx = text.index(' ')  # first occurrence
-                else:  # no space, read to end
-                    idx = len(text)
-                if idx != 0:
-                    try:
-                        weight = float(text[:idx])
-                    except:  # couldn't treat as float
-                        print(
-                            f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
-                        )
-                        weight = 1.0
-                else:  # no value found
-                    weight = 1.0
-                # remove from main text
-                remaining -= idx
-                text = text[idx + 1:]
-                # append the sub-prompt and its weight
-                prompts.append(prompt)
-                weights.append(weight)
-            else:  # no ':' found
-                if len(text) > 0:  # there is still text though
-                    # take remainder as weight 1
-                    prompts.append(text)
-                    weights.append(1.0)
-                remaining = 0
-        return prompts, weights
+        prompt_parser = re.compile("""
+            (?P<prompt>         # capture group for 'prompt'
+            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
+            )                   # end 'prompt'
+            (?:                 # non-capture group
+            :+                  # match one or more ':' characters
+            (?P<weight>         # capture group for 'weight'
+            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
+            )?                  # end weight capture group, make optional
+            \s*                 # strip spaces after weight
+            |                   # OR
+            $                   # else, if no ':' then match end of line
+            )                   # end non-capture group
+            """, re.VERBOSE)
+        parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
+        if skip_normalize:
+            return parsed_prompts
+        weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+        if weight_sum == 0:
+            print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
+            equal_weight = 1 / len(parsed_prompts)
+            return [(x[0], equal_weight) for x in parsed_prompts]
+        return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
 
     # shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
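
To illustrate what the new parser returns, here is a condensed standalone sketch of the same regex approach, with example inputs and their expected outputs in the trailing comments (split_weighted_subprompts is a module-level copy for demonstration, not part of the codebase):

import re

def split_weighted_subprompts(text, skip_normalize=False):
    # condensed version of the regex above: a sub-prompt is any run of
    # non-':' characters (or escaped '\:'), optionally followed by ':weight'
    prompt_parser = re.compile("""
        (?P<prompt>(?:\\\:|[^:])+)                # sub-prompt text
        (?::+(?P<weight>-?\d+(?:\.\d+)?)?\s*|$)   # optional ':weight', or end of line
        """, re.VERBOSE)
    parsed = [(m.group("prompt").replace("\\:", ":"), float(m.group("weight") or 1))
              for m in prompt_parser.finditer(text)]
    if skip_normalize or not parsed:
        return parsed
    weight_sum = sum(w for _, w in parsed)
    if weight_sum == 0:
        return [(p, 1 / len(parsed)) for p, _ in parsed]  # fall back to even weights
    return [(p, w / weight_sum) for p, w in parsed]

print(split_weighted_subprompts("a cat:2 a dog:1"))
# [('a cat', 0.666...), ('a dog', 0.333...)]  -- weights normalized to sum to 1
print(split_weighted_subprompts("a cat:2 a dog", skip_normalize=True))
# [('a cat', 2.0), ('a dog', 1.0)]  -- missing weight defaults to 1.0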