Skip to content

Commit 0433b3d

Browse files
Add Warning When Image Is Too Large (#271)
* Add Warning When Image Is Too Large * fix incomprehensible formatting introduced by "blue" Co-authored-by: Lincoln Stein <[email protected]>
1 parent 4b560b5 commit 0433b3d

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

ldm/simplet2i.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from ldm.models.diffusion.plms import PLMSSampler
2828
from ldm.models.diffusion.ksampler import KSampler
2929
from ldm.dream.pngwriter import PngWriter
30-
from ldm.dream.image_util import InitImageResizer
3130
from ldm.dream.devices import choose_torch_device
3231

3332
"""Simplified text to image API for stable diffusion/latent diffusion
@@ -159,7 +158,7 @@ def __init__(
159158

160159
# for VRAM usage statistics
161160
self.session_peakmem = torch.cuda.max_memory_allocated() if self.device == 'cuda' else None
162-
161+
163162
if seed is None:
164163
self.seed = self._new_seed()
165164
else:
@@ -178,7 +177,8 @@ def prompt2png(self, prompt, outdir, **kwargs):
178177
outputs = []
179178
for image, seed in results:
180179
name = f'{prefix}.{seed}.png'
181-
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
180+
path = pngwriter.save_image_and_prompt_to_png(
181+
image, f'{prompt} -S{seed}', name)
182182
outputs.append([path, seed])
183183
return outputs
184184

@@ -488,7 +488,8 @@ def _get_uc_and_c(self, prompt, skip_normalize):
488488
uc = self.model.get_learned_conditioning([''])
489489

490490
# get weighted sub-prompts
491-
weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)
491+
weighted_subprompts = T2I._split_weighted_subprompts(
492+
prompt, skip_normalize)
492493

493494
if len(weighted_subprompts) > 1:
494495
# i dont know if this is correct.. but it works
@@ -531,7 +532,7 @@ def load_model(self):
531532
if self.model is None:
532533
seed_everything(self.seed)
533534
try:
534-
config = OmegaConf.load(self.config)
535+
config = OmegaConf.load(self.config)
535536
self.device = self._get_device()
536537
model = self._load_model_from_config(config, self.weights)
537538
if self.embedding_path is not None:
@@ -621,7 +622,7 @@ def _load_img(self, path, width, height):
621622
image.width, height)
622623
else:
623624
image = InitImageResizer(image).resize(width, height)
624-
resize_needed=False
625+
resize_needed = False
625626
if resize_needed:
626627
image = InitImageResizer(image).resize(
627628
new_image_width, new_image_height)
@@ -652,18 +653,20 @@ def _split_weighted_subprompts(text, skip_normalize=False):
652653
$ # else, if no ':' then match end of line
653654
) # end non-capture group
654655
""", re.VERBOSE)
655-
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
656+
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
657+
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
656658
if skip_normalize:
657659
return parsed_prompts
658660
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
659661
if weight_sum == 0:
660-
print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
662+
print(
663+
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
661664
equal_weight = 1 / len(parsed_prompts)
662665
return [(x[0], equal_weight) for x in parsed_prompts]
663666
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
664-
665-
# shows how the prompt is tokenized
666-
# usually tokens have '</w>' to indicate end-of-word,
667+
668+
# shows how the prompt is tokenized
669+
# usually tokens have '</w>' to indicate end-of-word,
667670
# but for readability it has been replaced with ' '
668671
def _log_tokenization(self, text):
669672
if not self.log_tokenization:
@@ -700,4 +703,8 @@ def _resolution_check(self, width, height, log=False):
700703
height = h
701704
width = w
702705
resize_needed = True
706+
707+
if (width * height) > (self.width * self.height):
708+
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
709+
703710
return width, height, resize_needed

0 commit comments

Comments
 (0)