@@ -39,7 +39,6 @@
 t2i = T2I(model       = <path>        // models/ldm/stable-diffusion-v1/model.ckpt
           config      = <path>        // configs/stable-diffusion/v1-inference.yaml
           iterations  = <integer>     // how many times to run the sampling (1)
-          batch_size  = <integer>     // how many images to generate per sampling (1)
           steps       = <integer>     // 50
           seed        = <integer>     // current system time
           sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']   // k_lms
@@ -98,7 +97,6 @@ class T2I:
         model
         config
         iterations
-        batch_size
         steps
         seed
         sampler_name
@@ -116,7 +114,6 @@ class T2I:
 
     def __init__(
         self,
-        batch_size=1,
         iterations=1,
         steps=50,
         seed=None,
@@ -138,7 +135,6 @@ def __init__(
         latent_diffusion_weights=False,
         device='cuda',
     ):
-        self.batch_size = batch_size
         self.iterations = iterations
         self.width = width
         self.height = height
@@ -174,9 +170,7 @@ def prompt2png(self, prompt, outdir, **kwargs):
         Optional named arguments are the same as those passed to T2I and prompt2image()
         """
         results = self.prompt2image(prompt, **kwargs)
-        pngwriter = PngWriter(
-            outdir, prompt, kwargs.get('batch_size', self.batch_size)
-        )
+        pngwriter = PngWriter(outdir, prompt)
         for r in results:
             pngwriter.write_image(r[0], r[1])
         return pngwriter.files_written
@@ -196,7 +190,6 @@ def prompt2image(
         self,
         # these are common
         prompt,
-        batch_size=None,
         iterations=None,
         steps=None,
         seed=None,
@@ -222,8 +215,7 @@ def prompt2image(
         ldm.prompt2image() is the common entry point for txt2img() and img2img()
         It takes the following arguments:
            prompt                          // prompt string (no default)
-           iterations                      // iterations (1); image count=iterations x batch_size
-           batch_size                      // images per iteration (1)
+           iterations                      // iterations (1); image count=iterations
            steps                           // refinement steps per iteration
            seed                            // seed for random number generator
            width                           // width of image, in multiples of 64 (512)
@@ -258,7 +250,6 @@ def process_image(image,seed):
         height = height or self.height
         cfg_scale = cfg_scale or self.cfg_scale
         ddim_eta = ddim_eta or self.ddim_eta
-        batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
         strength = strength or self.strength
         self.log_tokenization = log_tokenization
@@ -297,7 +288,6 @@ def process_image(image,seed):
                 images_iterator = self._img2img(
                     prompt,
                     precision_scope=scope,
-                    batch_size=batch_size,
                     steps=steps,
                     cfg_scale=cfg_scale,
                     ddim_eta=ddim_eta,
@@ -312,7 +302,6 @@ def process_image(image,seed):
                 images_iterator = self._txt2img(
                     prompt,
                     precision_scope=scope,
-                    batch_size=batch_size,
                     steps=steps,
                     cfg_scale=cfg_scale,
                     ddim_eta=ddim_eta,
@@ -325,11 +314,10 @@ def process_image(image,seed):
             with scope(self.device.type), self.model.ema_scope():
                 for n in trange(iterations, desc='Generating'):
                     seed_everything(seed)
-                    iter_images = next(images_iterator)
-                    for image in iter_images:
-                        results.append([image, seed])
-                        if image_callback is not None:
-                            image_callback(image, seed)
+                    image = next(images_iterator)
+                    results.append([image, seed])
+                    if image_callback is not None:
+                        image_callback(image, seed)
                     seed = self._new_seed()
 
         if upscale is not None or gfpgan_strength > 0:
@@ -399,7 +387,6 @@ def _txt2img(
         self,
         prompt,
         precision_scope,
-        batch_size,
         steps,
         cfg_scale,
         ddim_eta,
@@ -415,31 +402,30 @@ def _txt2img(
             sampler = self.sampler
 
         while True:
-            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
+            uc, c = self._get_uc_and_c(prompt, skip_normalize)
             shape = [
                 self.latent_channels,
                 height // self.downsampling_factor,
                 width // self.downsampling_factor,
             ]
             samples, _ = sampler.sample(
+                batch_size=1,
                 S=steps,
                 conditioning=c,
-                batch_size=batch_size,
                 shape=shape,
                 verbose=False,
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
                 eta=ddim_eta,
                 img_callback=callback
             )
-            yield self._samples_to_images(samples)
+            yield self._sample_to_image(samples)
 
     @torch.no_grad()
     def _img2img(
         self,
         prompt,
         precision_scope,
-        batch_size,
         steps,
         cfg_scale,
         ddim_eta,
@@ -464,7 +450,6 @@ def _img2img(
             sampler = self.sampler
 
         init_image = self._load_img(init_img, width, height).to(self.device)
-        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
         with precision_scope(self.device.type):
             init_latent = self.model.get_first_stage_encoding(
                 self.model.encode_first_stage(init_image)
@@ -478,11 +463,11 @@ def _img2img(
         # print(f"target t_enc is {t_enc} steps")
 
         while True:
-            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
+            uc, c = self._get_uc_and_c(prompt, skip_normalize)
 
             # encode (scaled latent)
             z_enc = sampler.stochastic_encode(
-                init_latent, torch.tensor([t_enc] * batch_size).to(self.device)
+                init_latent, torch.tensor([t_enc]).to(self.device)
             )
             # decode it
             samples = sampler.decode(
@@ -493,12 +478,12 @@ def _img2img(
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
             )
-            yield self._samples_to_images(samples)
+            yield self._sample_to_image(samples)
 
     # TODO: does this actually need to run every loop? does anything in it vary by random seed?
-    def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
+    def _get_uc_and_c(self, prompt, skip_normalize):
 
-        uc = self.model.get_learned_conditioning(batch_size * [''])
+        uc = self.model.get_learned_conditioning([''])
 
         # weighted sub-prompts
         subprompts, weights = T2I._split_weighted_subprompts(prompt)
@@ -515,27 +500,23 @@ def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
                 self._log_tokenization(subprompts[i])
                 c = torch.add(
                     c,
-                    self.model.get_learned_conditioning(
-                        batch_size * [subprompts[i]]
-                    ),
+                    self.model.get_learned_conditioning([subprompts[i]]),
                     alpha=weight,
                 )
         else:   # just standard 1 prompt
             self._log_tokenization(prompt)
-            c = self.model.get_learned_conditioning(batch_size * [prompt])
+            c = self.model.get_learned_conditioning([prompt])
         return (uc, c)
 
-    def _samples_to_images(self, samples):
+    def _sample_to_image(self, samples):
         x_samples = self.model.decode_first_stage(samples)
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-        images = list()
-        for x_sample in x_samples:
-            x_sample = 255.0 * rearrange(
-                x_sample.cpu().numpy(), 'c h w -> h w c'
-            )
-            image = Image.fromarray(x_sample.astype(np.uint8))
-            images.append(image)
-        return images
+        if len(x_samples) != 1:
+            raise Exception(f'expected to get a single image, but got {len(x_samples)}')
+        x_sample = 255.0 * rearrange(
+            x_samples[0].cpu().numpy(), 'c h w -> h w c'
+        )
+        return Image.fromarray(x_sample.astype(np.uint8))
 
     def _new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
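
For reference, a minimal usage sketch (not part of the commit) of how the batch-free API reads after this change. The class, method names, argument names, and the [image, seed] result shape come from the hunks above; the module path, prompt text, and output filenames are illustrative assumptions.

    # Sketch only: module path assumed from the repository layout.
    from ldm.simplet2i import T2I

    # Total image count is now just `iterations`; each iteration produces
    # one image and then reseeds via _new_seed().
    t2i = T2I(iterations=3, steps=50, sampler_name='k_lms')

    # prompt2image() returns a list of [image, seed] pairs, as consumed
    # by the loop in prompt2png() above.
    results = t2i.prompt2image(prompt='a watercolor lighthouse at dusk')
    for image, seed in results:
        image.save(f'out-{seed}.png')   # hypothetical output naming

Callers that previously relied on batch_size > 1 get the same total count by raising iterations, with the difference that every image now gets its own seed.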