@@ -653,6 +653,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
653
653
654
654
if opts .use_scale_latent_for_hires_fix :
655
655
samples = torch .nn .functional .interpolate (samples , size = (self .height // opt_f , self .width // opt_f ), mode = "bilinear" )
656
+ image_conditioning = self .txt2img_image_conditioning (samples )
656
657
657
658
else :
658
659
decoded_samples = decode_first_stage (self .sd_model , samples )
@@ -674,6 +675,12 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
674
675
675
676
samples = self .sd_model .get_first_stage_encoding (self .sd_model .encode_first_stage (decoded_samples ))
676
677
678
+ image_conditioning = self .img2img_image_conditioning (
679
+ decoded_samples ,
680
+ samples ,
681
+ decoded_samples .new_ones (decoded_samples .shape [0 ], 1 , decoded_samples .shape [2 ], decoded_samples .shape [3 ])
682
+ )
683
+
677
684
shared .state .nextjob ()
678
685
679
686
self .sampler = sd_samplers .create_sampler_with_index (sd_samplers .samplers , self .sampler_index , self .sd_model )
@@ -684,11 +691,6 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
684
691
x = None
685
692
devices .torch_gc ()
686
693
687
- image_conditioning = self .img2img_image_conditioning (
688
- decoded_samples ,
689
- samples ,
690
- decoded_samples .new_ones (decoded_samples .shape [0 ], 1 , decoded_samples .shape [2 ], decoded_samples .shape [3 ])
691
- )
692
694
samples = self .sampler .sample_img2img (self , samples , noise , conditioning , unconditional_conditioning , steps = self .steps , image_conditioning = image_conditioning )
693
695
694
696
return samples
0 commit comments