@@ -199,7 +199,7 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask = No
199
199
def init (self , all_prompts , all_seeds , all_subseeds ):
200
200
pass
201
201
202
- def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength ):
202
+ def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , prompts ):
203
203
raise NotImplementedError ()
204
204
205
205
def close (self ):
@@ -521,11 +521,7 @@ def infotext(iteration=0, position_in_batch=0):
521
521
shared .state .job = f"Batch { n + 1 } out of { p .n_iter } "
522
522
523
523
with devices .autocast ():
524
- # Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix.
525
- if isinstance (p , StableDiffusionProcessingTxt2Img ):
526
- samples_ddim = p .sample (conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength , n = n )
527
- else :
528
- samples_ddim = p .sample (conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength )
524
+ samples_ddim = p .sample (conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength , prompts = prompts )
529
525
530
526
samples_ddim = samples_ddim .to (devices .dtype_vae )
531
527
x_samples_ddim = decode_first_stage (p .sd_model , samples_ddim )
@@ -653,7 +649,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
653
649
self .truncate_x = int (self .firstphase_width - firstphase_width_truncated ) // opt_f
654
650
self .truncate_y = int (self .firstphase_height - firstphase_height_truncated ) // opt_f
655
651
656
- def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , n = 0 ):
652
+ def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , prompts ):
657
653
self .sampler = sd_samplers .create_sampler_with_index (sd_samplers .samplers , self .sampler_index , self .sd_model )
658
654
659
655
if not self .enable_hr :
@@ -666,9 +662,21 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
666
662
667
663
samples = samples [:, :, self .truncate_y // 2 :samples .shape [2 ]- self .truncate_y // 2 , self .truncate_x // 2 :samples .shape [3 ]- self .truncate_x // 2 ]
668
664
665
+ """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
666
+ def save_intermediate (image , index ):
667
+ if not opts .save or self .do_not_save_samples or not opts .save_images_before_highres_fix :
668
+ return
669
+
670
+ if not isinstance (image , Image .Image ):
671
+ image = sd_samplers .sample_to_image (image , index )
672
+
673
+ images .save_image (image , self .outpath_samples , "" , seeds [index ], prompts [index ], opts .samples_format , suffix = "-before-highres-fix" )
674
+
669
675
if opts .use_scale_latent_for_hires_fix :
670
676
samples = torch .nn .functional .interpolate (samples , size = (self .height // opt_f , self .width // opt_f ), mode = "bilinear" )
671
677
678
+ for i in range (samples .shape [0 ]):
679
+ save_intermediate (samples , i )
672
680
else :
673
681
decoded_samples = decode_first_stage (self .sd_model , samples )
674
682
lowres_samples = torch .clamp ((decoded_samples + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
@@ -678,6 +686,9 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
678
686
x_sample = 255. * np .moveaxis (x_sample .cpu ().numpy (), 0 , 2 )
679
687
x_sample = x_sample .astype (np .uint8 )
680
688
image = Image .fromarray (x_sample )
689
+
690
+ save_intermediate (image , i )
691
+
681
692
image = images .resize_image (0 , image , self .width , self .height )
682
693
image = np .array (image ).astype (np .float32 ) / 255.0
683
694
image = np .moveaxis (image , 2 , 0 )
@@ -689,15 +700,6 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
689
700
690
701
samples = self .sd_model .get_first_stage_encoding (self .sd_model .encode_first_stage (decoded_samples ))
691
702
692
- # Save a copy of the image/s before doing highres fix, if applicable.
693
- if opts .save and not self .do_not_save_samples and opts .save_images_before_highres_fix :
694
- for i in range (self .batch_size ):
695
- # This batch's ith image.
696
- img = sd_samplers .sample_to_image (samples , i )
697
- # Index that accounts for both batch size and batch count.
698
- ind = i + self .batch_size * n
699
- images .save_image (img , self .outpath_samples , "" , self .all_seeds [ind ], self .all_prompts [ind ], opts .samples_format , suffix = f"-before-highres-fix" )
700
-
701
703
shared .state .nextjob ()
702
704
703
705
self .sampler = sd_samplers .create_sampler_with_index (sd_samplers .samplers , self .sampler_index , self .sd_model )
@@ -844,8 +846,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
844
846
845
847
self .image_conditioning = self .img2img_image_conditioning (image , self .init_latent , self .image_mask )
846
848
847
-
848
- def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength ):
849
+ def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , prompts ):
849
850
x = create_random_tensors ([opt_C , self .height // opt_f , self .width // opt_f ], seeds = seeds , subseeds = subseeds , subseed_strength = self .subseed_strength , seed_resize_from_h = self .seed_resize_from_h , seed_resize_from_w = self .seed_resize_from_w , p = self )
850
851
851
852
samples = self .sampler .sample_img2img (self , self .init_latent , x , conditioning , unconditional_conditioning , image_conditioning = self .image_conditioning )
@@ -856,4 +857,4 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
856
857
del x
857
858
devices .torch_gc ()
858
859
859
- return samples
860
+ return samples
0 commit comments