@@ -1591,7 +1591,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1591 1591                # encode batch prompts when custom prompts are provided for each image
1592 1592                if train_dataset.custom_instance_prompts:
1593 1593                    prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
1594      -
     1594 +                else:
     1595 +                    prompt_embeds = prompt_embeds.repeat(len(prompts), 1, 1)
     1596 +                    pooled_prompt_embeds = pooled_prompt_embeds.repeat(len(prompts), 1)
1595 1597                # Convert images to latent space
1596 1598                if args.cache_latents:
1597 1599                    model_input = latents_cache[step].sample()
@@ -1646,12 +1648,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1646 1648                # Predict the noise residual
1647 1649                model_pred = transformer(
1648 1650                    hidden_states=noisy_model_input,
1649      -                    encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1)
1650      -                    if not train_dataset.custom_instance_prompts
1651      -                    else prompt_embeds,
1652      -                    pooled_embeds=pooled_prompt_embeds.repeat(len(prompts), 1)
1653      -                    if not train_dataset.custom_instance_prompts
1654      -                    else pooled_prompt_embeds,
     1651 +                    encoder_hidden_states=prompt_embeds,
     1652 +                    pooled_embeds=pooled_prompt_embeds,
1655 1653                    timestep=timesteps,
1656 1654                    img_sizes=img_sizes,
1657 1655                    img_ids=img_ids,
0 commit comments