@@ -1234,6 +1234,13 @@ def __call__(
12341234 device ,
12351235 generator ,
12361236 )
1237+
1238+ timestep_cond = None
1239+ if self .unet .config .time_cond_proj_dim is not None :
1240+ guidance_scale_tensor = torch .tensor (self .guidance_scale - 1 ).repeat (batch_size * num_images_per_prompt )
1241+ timestep_cond = self .get_guidance_scale_embedding (
1242+ guidance_scale_tensor , embedding_dim = self .unet .config .time_cond_proj_dim
1243+ ).to (device = device , dtype = latents .dtype )
12371244
12381245 # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
12391246 extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
@@ -1317,6 +1324,7 @@ def __call__(
13171324 latent_model_input ,
13181325 t ,
13191326 encoder_hidden_states = prompt_embeds ,
1327+ timestep_cond = timestep_cond ,
13201328 cross_attention_kwargs = self .cross_attention_kwargs ,
13211329 down_block_additional_residuals = down_block_res_samples ,
13221330 mid_block_additional_residual = mid_block_res_sample ,
@@ -1344,6 +1352,7 @@ def __call__(
13441352
13451353 latents = callback_outputs .pop ("latents" , latents )
13461354 prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
1355+ negative_prompt_embeds = callback_outputs .pop ("negative_prompt_embeds" , negative_prompt_embeds )
13471356 # call the callback, if provided
13481357 if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
13491358 progress_bar .update ()
0 commit comments