@@ -1253,21 +1253,21 @@ def __call__(
             ]
             controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
 
-        control_images = control_image if isinstance(control_image, list) else [control_image]
-        for i, single_image in enumerate(control_images):
-            if self.do_classifier_free_guidance:
-                single_image = single_image.chunk(2)[0]
-
-            if self.do_perturbed_attention_guidance:
-                single_image = self._prepare_perturbed_attention_guidance(
-                    single_image, single_image, self.do_classifier_free_guidance
-                )
-            elif self.do_classifier_free_guidance:
-                single_image = torch.cat([single_image] * 2)
-            single_image = single_image.to(device)
-            control_images[i] = single_image
-
-        control_image = control_images if isinstance(control_image, list) else control_images[0]
+        # control_images = control_image if isinstance(control_image, list) else [control_image]
+        # for i, single_image in enumerate(control_images):
+        #     if self.do_classifier_free_guidance:
+        #         single_image = single_image.chunk(2)[0]
+
+        #     if self.do_perturbed_attention_guidance:
+        #         single_image = self._prepare_perturbed_attention_guidance(
+        #             single_image, single_image, self.do_classifier_free_guidance
+        #         )
+        #     elif self.do_classifier_free_guidance:
+        #         single_image = torch.cat([single_image] * 2)
+        #     single_image = single_image.to(device)
+        #     control_images[i] = single_image
+
+        # control_image = control_images if isinstance(control_image, list) else control_images[0]
 
         prompt_embeds = prompt_embeds.to(device)
 
@@ -1285,12 +1285,22 @@ def __call__(
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
-                control_model_input = latent_model_input
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    control_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
 
                 if isinstance(controlnet_keep[i], list):
                     cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -1299,16 +1309,23 @@ def __call__(
                     if isinstance(controlnet_cond_scale, list):
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=control_image,
                     conditioning_scale=cond_scale,
-                    guess_mode=False,
+                    guess_mode=guess_mode,
                     return_dict=False,
                 )
 
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Inferred ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
                 # predict the noise residual
                 noise_pred = self.unet(
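As a side note on the guess-mode branch added above: ControlNet is run only on the conditional half of the batch, and zeros are concatenated in front of its residuals so the unconditional UNet pass receives no extra conditioning. Below is a minimal standalone sketch of that padding pattern; the tensor shape and variable names are illustrative only and not taken from the pipeline.

```python
import torch

# Residual as ControlNet might return it for the conditional batch only
# (illustrative shape: batch 1, 320 channels, 64x64 latent grid).
cond_residual = torch.randn(1, 320, 64, 64)

# Pad with zeros for the unconditional half, mirroring
# torch.cat([torch.zeros_like(d), d]) in the diff above.
padded_residual = torch.cat([torch.zeros_like(cond_residual), cond_residual])

# Index 0 (unconditional) is all zeros, so adding it to the UNet's skip
# connections leaves the unconditional prediction unchanged; index 1
# (conditional) carries the ControlNet conditioning.
assert padded_residual.shape[0] == 2
assert torch.all(padded_residual[0] == 0)
assert torch.equal(padded_residual[1], cond_residual[0])
```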