Skip to content

Commit 1714816

Browse files
authored
remove support for batch_size from dream.py (#227)
* remove dream.py support for batch_size * expect to get a single image
1 parent b5565d2 commit 1714816

File tree

6 files changed

+27
-56
lines changed

6 files changed

+27
-56
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,13 +297,13 @@ face enhancement (see previous section):
297297
```
298298
dream> a cute child playing hopscotch -G0.5
299299
[...]
300-
outputs/img-samples/000039.3498014304.png: "a cute child playing hopscotch" -s50 -b1 -W512 -H512 -C7.5 -mk_lms -S3498014304
300+
outputs/img-samples/000039.3498014304.png: "a cute child playing hopscotch" -s50 -W512 -H512 -C7.5 -mk_lms -S3498014304
301301
302302
# I wonder what it will look like if I bump up the steps and set facial enhancement to full strength?
303303
dream> a cute child playing hopscotch -G1.0 -s100 -S -1
304304
reusing previous seed 3498014304
305305
[...]
306-
outputs/img-samples/000040.3498014304.png: "a cute child playing hopscotch" -G1.0 -s100 -b1 -W512 -H512 -C7.5 -mk_lms -S3498014304
306+
outputs/img-samples/000040.3498014304.png: "a cute child playing hopscotch" -G1.0 -s100 -W512 -H512 -C7.5 -mk_lms -S3498014304
307307
```
308308

309309
## Weighted Prompts

ldm/dream/pngwriter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def normalize_prompt(self):
117117
switches = list()
118118
switches.append(f'"{opt.prompt}"')
119119
switches.append(f'-s{opt.steps or t2i.steps}')
120-
switches.append(f'-b{opt.batch_size or t2i.batch_size}')
121120
switches.append(f'-W{opt.width or t2i.width}')
122121
switches.append(f'-H{opt.height or t2i.height}')
123122
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')

ldm/dream/readline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def _path_completions(self, text, state, extensions):
8989
'--steps','-s',
9090
'--seed','-S',
9191
'--iterations','-n',
92-
'--batch_size','-b',
9392
'--width','-W','--height','-H',
9493
'--cfg_scale','-C',
9594
'--grid','-g',

ldm/dream/server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ def image_progress(sample, step):
140140
# since rendering images is moderately expensive, only render every 5th image
141141
# and don't bother with the last one, since it'll render anyway
142142
if progress_images and step % 5 == 0 and step < steps - 1:
143-
images = self.model._samples_to_images(sample)
144-
image = images[0]
143+
image = self.model._sample_to_image(sample)
145144
step_writer.write_image(image, seed) # TODO PngWriter to return path
146145
url = step_writer.filepath
147146
self.wfile.write(bytes(json.dumps(

ldm/simplet2i.py

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ckpt
4040
config = <path> // configs/stable-diffusion/v1-inference.yaml
4141
iterations = <integer> // how many times to run the sampling (1)
42-
batch_size = <integer> // how many images to generate per sampling (1)
4342
steps = <integer> // 50
4443
seed = <integer> // current system time
4544
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
@@ -98,7 +97,6 @@ class T2I:
9897
model
9998
config
10099
iterations
101-
batch_size
102100
steps
103101
seed
104102
sampler_name
@@ -116,7 +114,6 @@ class T2I:
116114

117115
def __init__(
118116
self,
119-
batch_size=1,
120117
iterations=1,
121118
steps=50,
122119
seed=None,
@@ -138,7 +135,6 @@ def __init__(
138135
latent_diffusion_weights=False,
139136
device='cuda',
140137
):
141-
self.batch_size = batch_size
142138
self.iterations = iterations
143139
self.width = width
144140
self.height = height
@@ -174,9 +170,7 @@ def prompt2png(self, prompt, outdir, **kwargs):
174170
Optional named arguments are the same as those passed to T2I and prompt2image()
175171
"""
176172
results = self.prompt2image(prompt, **kwargs)
177-
pngwriter = PngWriter(
178-
outdir, prompt, kwargs.get('batch_size', self.batch_size)
179-
)
173+
pngwriter = PngWriter(outdir, prompt)
180174
for r in results:
181175
pngwriter.write_image(r[0], r[1])
182176
return pngwriter.files_written
@@ -196,7 +190,6 @@ def prompt2image(
196190
self,
197191
# these are common
198192
prompt,
199-
batch_size=None,
200193
iterations=None,
201194
steps=None,
202195
seed=None,
@@ -222,8 +215,7 @@ def prompt2image(
222215
ldm.prompt2image() is the common entry point for txt2img() and img2img()
223216
It takes the following arguments:
224217
prompt // prompt string (no default)
225-
iterations // iterations (1); image count=iterations x batch_size
226-
batch_size // images per iteration (1)
218+
iterations // iterations (1); image count=iterations
227219
steps // refinement steps per iteration
228220
seed // seed for random number generator
229221
width // width of image, in multiples of 64 (512)
@@ -258,7 +250,6 @@ def process_image(image,seed):
258250
height = height or self.height
259251
cfg_scale = cfg_scale or self.cfg_scale
260252
ddim_eta = ddim_eta or self.ddim_eta
261-
batch_size = batch_size or self.batch_size
262253
iterations = iterations or self.iterations
263254
strength = strength or self.strength
264255
self.log_tokenization = log_tokenization
@@ -297,7 +288,6 @@ def process_image(image,seed):
297288
images_iterator = self._img2img(
298289
prompt,
299290
precision_scope=scope,
300-
batch_size=batch_size,
301291
steps=steps,
302292
cfg_scale=cfg_scale,
303293
ddim_eta=ddim_eta,
@@ -312,7 +302,6 @@ def process_image(image,seed):
312302
images_iterator = self._txt2img(
313303
prompt,
314304
precision_scope=scope,
315-
batch_size=batch_size,
316305
steps=steps,
317306
cfg_scale=cfg_scale,
318307
ddim_eta=ddim_eta,
@@ -325,11 +314,10 @@ def process_image(image,seed):
325314
with scope(self.device.type), self.model.ema_scope():
326315
for n in trange(iterations, desc='Generating'):
327316
seed_everything(seed)
328-
iter_images = next(images_iterator)
329-
for image in iter_images:
330-
results.append([image, seed])
331-
if image_callback is not None:
332-
image_callback(image, seed)
317+
image = next(images_iterator)
318+
results.append([image, seed])
319+
if image_callback is not None:
320+
image_callback(image, seed)
333321
seed = self._new_seed()
334322

335323
if upscale is not None or gfpgan_strength > 0:
@@ -399,7 +387,6 @@ def _txt2img(
399387
self,
400388
prompt,
401389
precision_scope,
402-
batch_size,
403390
steps,
404391
cfg_scale,
405392
ddim_eta,
@@ -415,31 +402,30 @@ def _txt2img(
415402
sampler = self.sampler
416403

417404
while True:
418-
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
405+
uc, c = self._get_uc_and_c(prompt, skip_normalize)
419406
shape = [
420407
self.latent_channels,
421408
height // self.downsampling_factor,
422409
width // self.downsampling_factor,
423410
]
424411
samples, _ = sampler.sample(
412+
batch_size=1,
425413
S=steps,
426414
conditioning=c,
427-
batch_size=batch_size,
428415
shape=shape,
429416
verbose=False,
430417
unconditional_guidance_scale=cfg_scale,
431418
unconditional_conditioning=uc,
432419
eta=ddim_eta,
433420
img_callback=callback
434421
)
435-
yield self._samples_to_images(samples)
422+
yield self._sample_to_image(samples)
436423

437424
@torch.no_grad()
438425
def _img2img(
439426
self,
440427
prompt,
441428
precision_scope,
442-
batch_size,
443429
steps,
444430
cfg_scale,
445431
ddim_eta,
@@ -464,7 +450,6 @@ def _img2img(
464450
sampler = self.sampler
465451

466452
init_image = self._load_img(init_img,width,height).to(self.device)
467-
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
468453
with precision_scope(self.device.type):
469454
init_latent = self.model.get_first_stage_encoding(
470455
self.model.encode_first_stage(init_image)
@@ -478,11 +463,11 @@ def _img2img(
478463
# print(f"target t_enc is {t_enc} steps")
479464

480465
while True:
481-
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
466+
uc, c = self._get_uc_and_c(prompt, skip_normalize)
482467

483468
# encode (scaled latent)
484469
z_enc = sampler.stochastic_encode(
485-
init_latent, torch.tensor([t_enc] * batch_size).to(self.device)
470+
init_latent, torch.tensor([t_enc]).to(self.device)
486471
)
487472
# decode it
488473
samples = sampler.decode(
@@ -493,12 +478,12 @@ def _img2img(
493478
unconditional_guidance_scale=cfg_scale,
494479
unconditional_conditioning=uc,
495480
)
496-
yield self._samples_to_images(samples)
481+
yield self._sample_to_image(samples)
497482

498483
# TODO: does this actually need to run every loop? does anything in it vary by random seed?
499-
def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
484+
def _get_uc_and_c(self, prompt, skip_normalize):
500485

501-
uc = self.model.get_learned_conditioning(batch_size * [''])
486+
uc = self.model.get_learned_conditioning([''])
502487

503488
# weighted sub-prompts
504489
subprompts, weights = T2I._split_weighted_subprompts(prompt)
@@ -515,27 +500,23 @@ def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
515500
self._log_tokenization(subprompts[i])
516501
c = torch.add(
517502
c,
518-
self.model.get_learned_conditioning(
519-
batch_size * [subprompts[i]]
520-
),
503+
self.model.get_learned_conditioning([subprompts[i]]),
521504
alpha=weight,
522505
)
523506
else: # just standard 1 prompt
524507
self._log_tokenization(prompt)
525-
c = self.model.get_learned_conditioning(batch_size * [prompt])
508+
c = self.model.get_learned_conditioning([prompt])
526509
return (uc, c)
527510

528-
def _samples_to_images(self, samples):
511+
def _sample_to_image(self, samples):
529512
x_samples = self.model.decode_first_stage(samples)
530513
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
531-
images = list()
532-
for x_sample in x_samples:
533-
x_sample = 255.0 * rearrange(
534-
x_sample.cpu().numpy(), 'c h w -> h w c'
535-
)
536-
image = Image.fromarray(x_sample.astype(np.uint8))
537-
images.append(image)
538-
return images
514+
if len(x_samples) != 1:
515+
raise Exception(f'expected to get a single image, but got {len(x_samples)}')
516+
x_sample = 255.0 * rearrange(
517+
x_samples[0].cpu().numpy(), 'c h w -> h w c'
518+
)
519+
return Image.fromarray(x_sample.astype(np.uint8))
539520

540521
def _new_seed(self):
541522
self.seed = random.randrange(0, np.iinfo(np.uint32).max)

scripts/dream.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
199199

200200
# Here is where the images are actually generated!
201201
try:
202-
file_writer = PngWriter(current_outdir, normalized_prompt, opt.batch_size)
202+
file_writer = PngWriter(current_outdir, normalized_prompt)
203203
callback = file_writer.write_image if individual_images else None
204204
image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
205205
results = (
@@ -419,13 +419,6 @@ def create_cmd_parser():
419419
default=1,
420420
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
421421
)
422-
parser.add_argument(
423-
'-b',
424-
'--batch_size',
425-
type=int,
426-
default=1,
427-
help='Number of images to produce per sampling (will not provide seeds for individual images!)',
428-
)
429422
parser.add_argument(
430423
'-W', '--width', type=int, help='Image width, multiple of 64'
431424
)

0 commit comments

Comments
 (0)