Commit c52ba1b

feat: simplify and enhance prompt weight splitting (#258)

* feat: simplify and enhance prompt weight splitting
* fix: don't shadow the prompt variable
* feat: enable backslash-escaped colons in prompts
1 parent d022d0d commit c52ba1b
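
As a quick illustration of the new syntax: each colon-delimited section of a prompt may carry a weight, and `\:` now escapes a literal colon. A minimal usage sketch, assuming the parser in the diff below and that the repository's dependencies are importable; the sample prompt and the printed tuples are illustrative, not taken from the commit:

    from ldm.simplet2i import T2I

    prompt = r'a mountain lake:2 dramatic sunset:1 note\: no people'

    # With skip_normalize=True the raw (subprompt, weight) tuples are returned:
    #   [('a mountain lake', 2.0), ('dramatic sunset', 1.0), ('note: no people', 1.0)]
    print(T2I._split_weighted_subprompts(prompt, skip_normalize=True))

    # By default the weights are normalized to sum to 1:
    #   [('a mountain lake', 0.5), ('dramatic sunset', 0.25), ('note: no people', 0.25)]
    print(T2I._split_weighted_subprompts(prompt))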

File tree

1 file changed: +32 -51 lines changed

ldm/simplet2i.py

Lines changed: 32 additions & 51 deletions
@@ -487,22 +487,19 @@ def _get_uc_and_c(self, prompt, skip_normalize):
 
         uc = self.model.get_learned_conditioning([''])
 
-        # weighted sub-prompts
-        subprompts, weights = T2I._split_weighted_subprompts(prompt)
-        if len(subprompts) > 1:
+        # get weighted sub-prompts
+        weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)
+
+        if len(weighted_subprompts) > 1:
             # i dont know if this is correct.. but it works
             c = torch.zeros_like(uc)
-            # get total weight for normalizing
-            totalWeight = sum(weights)
             # normalize each "sub prompt" and add it
-            for i in range(0, len(subprompts)):
-                weight = weights[i]
-                if not skip_normalize:
-                    weight = weight / totalWeight
-                self._log_tokenization(subprompts[i])
+            for i in range(0, len(weighted_subprompts)):
+                subprompt, weight = weighted_subprompts[i]
+                self._log_tokenization(subprompt)
                 c = torch.add(
                     c,
-                    self.model.get_learned_conditioning([subprompts[i]]),
+                    self.model.get_learned_conditioning([subprompt]),
                     alpha=weight,
                 )
         else:   # just standard 1 prompt
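
The loop above folds each sub-prompt's conditioning into `c` with `torch.add(..., alpha=weight)`, i.e. `c = c + weight * conditioning(subprompt)`. A minimal sketch of that accumulation with stand-in tensors; `fake_conditioning` is a hypothetical placeholder for `self.model.get_learned_conditioning`, with an arbitrary placeholder shape:

    import torch

    def fake_conditioning(prompt):
        # Hypothetical stand-in for self.model.get_learned_conditioning([prompt]);
        # shape and values are placeholders, not the real text-encoder output.
        torch.manual_seed(abs(hash(prompt)) % (2 ** 31))
        return torch.randn(1, 77, 768)

    weighted_subprompts = [('a mountain lake', 0.5), ('dramatic sunset', 0.5)]

    c = torch.zeros_like(fake_conditioning(''))
    for subprompt, weight in weighted_subprompts:
        # same accumulation as the diff: c = c + weight * conditioning(subprompt)
        c = torch.add(c, fake_conditioning(subprompt), alpha=weight)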
@@ -616,52 +613,36 @@ def _load_img(self, path, width, height):
         image = torch.from_numpy(image)
         return 2.0 * image - 1.0
 
-    def _split_weighted_subprompts(text):
+    def _split_weighted_subprompts(text, skip_normalize=False):
         """
         grabs all text up to the first occurrence of ':'
         uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
         if ':' has no value defined, defaults to 1.0
         repeats until no text remaining
         """
-        remaining = len(text)
-        prompts = []
-        weights = []
-        while remaining > 0:
-            if ':' in text:
-                idx = text.index(':')   # first occurrence from start
-                # grab up to index as sub-prompt
-                prompt = text[:idx]
-                remaining -= idx
-                # remove from main text
-                text = text[idx + 1 :]
-                # find value for weight
-                if ' ' in text:
-                    idx = text.index(' ')   # first occurence
-                else:   # no space, read to end
-                    idx = len(text)
-                if idx != 0:
-                    try:
-                        weight = float(text[:idx])
-                    except:   # couldn't treat as float
-                        print(
-                            f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
-                        )
-                        weight = 1.0
-                else:   # no value found
-                    weight = 1.0
-                # remove from main text
-                remaining -= idx
-                text = text[idx + 1 :]
-                # append the sub-prompt and its weight
-                prompts.append(prompt)
-                weights.append(weight)
-            else:   # no : found
-                if len(text) > 0:   # there is still text though
-                    # take remainder as weight 1
-                    prompts.append(text)
-                    weights.append(1.0)
-                remaining = 0
-        return prompts, weights
+        prompt_parser = re.compile("""
+            (?P<prompt>         # capture group for 'prompt'
+            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
+            )                   # end 'prompt'
+            (?:                 # non-capture group
+            :+                  # match one or more ':' characters
+            (?P<weight>         # capture group for 'weight'
+            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
+            )?                  # end weight capture group, make optional
+            \s*                 # strip spaces after weight
+            |                   # OR
+            $                   # else, if no ':' then match end of line
+            )                   # end non-capture group
+            """, re.VERBOSE)
+        parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
+        if skip_normalize:
+            return parsed_prompts
+        weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+        if weight_sum == 0:
+            print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
+            equal_weight = 1 / len(parsed_prompts)
+            return [(x[0], equal_weight) for x in parsed_prompts]
+        return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
 
     # shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
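
The normalization branch at the end of the new `_split_weighted_subprompts` divides each weight by the total, and falls back to even weights when the weights cancel out to zero, so a prompt like `cat:1 dog:-1` does not trigger a division by zero. A standalone sketch of just that step, assuming the parser already produced `(subprompt, weight)` tuples; the sample inputs and printed results are illustrative:

    def normalize_weights(parsed_prompts):
        # Mirrors the normalization in _split_weighted_subprompts above.
        weight_sum = sum(weight for _, weight in parsed_prompts)
        if weight_sum == 0:
            # e.g. 'cat:1 dog:-1' -- discard the weights and spread them evenly
            equal_weight = 1 / len(parsed_prompts)
            return [(subprompt, equal_weight) for subprompt, _ in parsed_prompts]
        return [(subprompt, weight / weight_sum) for subprompt, weight in parsed_prompts]

    print(normalize_weights([('cat', 2.0), ('dog', 1.0), ('fog', 1.0)]))
    # [('cat', 0.5), ('dog', 0.25), ('fog', 0.25)]
    print(normalize_weights([('cat', 1.0), ('dog', -1.0)]))
    # [('cat', 0.5), ('dog', 0.5)]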
