@@ -521,7 +521,11 @@ def infotext(iteration=0, position_in_batch=0):
521
521
shared .state .job = f"Batch { n + 1 } out of { p .n_iter } "
522
522
523
523
with devices .autocast ():
524
- samples_ddim = p .sample (conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength )
524
+ # Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix.
525
+ if isinstance (p , StableDiffusionProcessingTxt2Img ):
526
+ samples_ddim = p .sample (conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength , n = n )
527
+ else :
528
+ samples_ddim = p .sample (conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength )
525
529
526
530
samples_ddim = samples_ddim .to (devices .dtype_vae )
527
531
x_samples_ddim = decode_first_stage (p .sd_model , samples_ddim )
@@ -649,7 +653,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
649
653
self .truncate_x = int (self .firstphase_width - firstphase_width_truncated ) // opt_f
650
654
self .truncate_y = int (self .firstphase_height - firstphase_height_truncated ) // opt_f
651
655
652
- def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength ):
656
+ def sample (self , conditioning , unconditional_conditioning , seeds , subseeds , subseed_strength , n = 0 ):
653
657
self .sampler = sd_samplers .create_sampler_with_index (sd_samplers .samplers , self .sampler_index , self .sd_model )
654
658
655
659
if not self .enable_hr :
@@ -685,6 +689,15 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
685
689
686
690
samples = self .sd_model .get_first_stage_encoding (self .sd_model .encode_first_stage (decoded_samples ))
687
691
692
+ # Save a copy of the image/s before doing highres fix, if applicable.
693
+ if opts .save and not self .do_not_save_samples and opts .save_images_before_highres_fix :
694
+ for i in range (self .batch_size ):
695
+ # This batch's ith image.
696
+ img = sd_samplers .sample_to_image (samples , i )
697
+ # Index that accounts for both batch size and batch count.
698
+ ind = i + self .batch_size * n
699
+ images .save_image (img , self .outpath_samples , "" , self .all_seeds [ind ], self .all_prompts [ind ], opts .samples_format , suffix = f"-before-highres-fix" )
700
+
688
701
shared .state .nextjob ()
689
702
690
703
self .sampler = sd_samplers .create_sampler_with_index (sd_samplers .samplers , self .sampler_index , self .sd_model )
0 commit comments