Skip to content

Commit 9ad7920

Browse files
committed
Merge branch 'main' of github.com:lstein/stable-diffusion into main
2 parents 0be2351 + ed51339 commit 9ad7920

File tree

3 files changed

+44
-56
lines changed

3 files changed

+44
-56
lines changed

ldm/simplet2i.py

Lines changed: 35 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -483,22 +483,19 @@ def _get_uc_and_c(self, prompt, skip_normalize):
483483

484484
uc = self.model.get_learned_conditioning([''])
485485

486-
# weighted sub-prompts
487-
subprompts, weights = T2I._split_weighted_subprompts(prompt)
488-
if len(subprompts) > 1:
486+
# get weighted sub-prompts
487+
weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)
488+
489+
if len(weighted_subprompts) > 1:
489490
# i dont know if this is correct.. but it works
490491
c = torch.zeros_like(uc)
491-
# get total weight for normalizing
492-
totalWeight = sum(weights)
493492
# normalize each "sub prompt" and add it
494-
for i in range(0, len(subprompts)):
495-
weight = weights[i]
496-
if not skip_normalize:
497-
weight = weight / totalWeight
498-
self._log_tokenization(subprompts[i])
493+
for i in range(0, len(weighted_subprompts)):
494+
subprompt, weight = weighted_subprompts[i]
495+
self._log_tokenization(subprompt)
499496
c = torch.add(
500497
c,
501-
self.model.get_learned_conditioning([subprompts[i]]),
498+
self.model.get_learned_conditioning([subprompt]),
502499
alpha=weight,
503500
)
504501
else: # just standard 1 prompt
@@ -630,55 +627,39 @@ def _load_img(self, path, width, height):
630627
image = torch.from_numpy(image)
631628
return 2.0 * image - 1.0
632629

633-
def _split_weighted_subprompts(text):
630+
def _split_weighted_subprompts(text, skip_normalize=False):
634631
"""
635632
grabs all text up to the first occurrence of ':'
636633
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
637634
if ':' has no value defined, defaults to 1.0
638635
repeats until no text remaining
639636
"""
640-
remaining = len(text)
641-
prompts = []
642-
weights = []
643-
while remaining > 0:
644-
if ':' in text:
645-
idx = text.index(':') # first occurrence from start
646-
# grab up to index as sub-prompt
647-
prompt = text[:idx]
648-
remaining -= idx
649-
# remove from main text
650-
text = text[idx + 1:]
651-
# find value for weight
652-
if ' ' in text:
653-
idx = text.index(' ') # first occurence
654-
else: # no space, read to end
655-
idx = len(text)
656-
if idx != 0:
657-
try:
658-
weight = float(text[:idx])
659-
except: # couldn't treat as float
660-
print(
661-
f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
662-
)
663-
weight = 1.0
664-
else: # no value found
665-
weight = 1.0
666-
# remove from main text
667-
remaining -= idx
668-
text = text[idx + 1:]
669-
# append the sub-prompt and its weight
670-
prompts.append(prompt)
671-
weights.append(weight)
672-
else: # no : found
673-
if len(text) > 0: # there is still text though
674-
# take remainder as weight 1
675-
prompts.append(text)
676-
weights.append(1.0)
677-
remaining = 0
678-
return prompts, weights
679-
680-
# shows how the prompt is tokenized
681-
# usually tokens have '</w>' to indicate end-of-word,
637+
prompt_parser = re.compile("""
638+
(?P<prompt> # capture group for 'prompt'
639+
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
640+
) # end 'prompt'
641+
(?: # non-capture group
642+
:+ # match one or more ':' characters
643+
(?P<weight> # capture group for 'weight'
644+
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
645+
)? # end weight capture group, make optional
646+
\s* # strip spaces after weight
647+
| # OR
648+
$ # else, if no ':' then match end of line
649+
) # end non-capture group
650+
""", re.VERBOSE)
651+
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
652+
if skip_normalize:
653+
return parsed_prompts
654+
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
655+
if weight_sum == 0:
656+
print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
657+
equal_weight = 1 / len(parsed_prompts)
658+
return [(x[0], equal_weight) for x in parsed_prompts]
659+
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
660+
661+
# shows how the prompt is tokenized
662+
# usually tokens have '</w>' to indicate end-of-word,
682663
# but for readability it has been replaced with ' '
683664
def _log_tokenization(self, text):
684665
if not self.log_tokenization:

scripts/dream.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ def main():
2929
width = 512
3030
height = 512
3131
config = 'configs/stable-diffusion/v1-inference.yaml'
32-
weights = 'models/ldm/stable-diffusion-v1/model.ckpt'
32+
if '.ckpt' in opt.weights:
33+
weights = opt.weights
34+
else:
35+
weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt'
3336

3437
print('* Initializing, be patient...\n')
3538
sys.path.append('.')
@@ -418,6 +421,11 @@ def create_argv_parser():
418421
action='store_true',
419422
help='Start in web server mode.',
420423
)
424+
parser.add_argument(
425+
'--weights',
426+
default='model',
427+
help='Indicates the Stable Diffusion model to use.',
428+
)
421429
return parser
422430

423431

static/dream_web/index.js

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ async function generateSubmit(form) {
9898
appendOutput(data.url, data.seed, data.config);
9999
progressEle.setAttribute('value', 0);
100100
progressEle.setAttribute('max', totalSteps);
101-
progressImageEle.src = BLANK_IMAGE_URL;
102101
} else if (data.event === 'upscaling-started') {
103102
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
104103
document.getElementById("scaling-inprocess-message").style.display = "block";

0 commit comments

Comments
 (0)