@@ -907,8 +907,8 @@ def _encode_prompt_with_llama(
     else:
         dtype = text_encoder.dtype

-    prompt_embeds = outputs.hidden_states[1:].to(dtype=dtype, device=device)
-    prompt_embeds = torch.stack(prompt_embeds, dim=0)
+    prompt_embeds = outputs.hidden_states[1:]
+    prompt_embeds = torch.stack(prompt_embeds, dim=0).to(dtype=dtype, device=device)
     _, _, seq_len, dim = prompt_embeds.shape

     # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
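In the hunk above, `outputs.hidden_states` is a tuple of per-layer tensors (as returned by Hugging Face text encoders with `output_hidden_states=True`), so `.to(...)` can only be applied after `torch.stack` has collapsed the tuple into a single tensor. A minimal sketch of that pattern, using small illustrative shapes rather than the script's real ones:

import torch

# stand-in for outputs.hidden_states: a tuple of per-layer [batch, seq_len, dim] tensors
hidden_states = tuple(torch.randn(2, 8, 16) for _ in range(4))

# a tuple has no .to(); stack first, then cast/move the resulting tensor
prompt_embeds = torch.stack(hidden_states[1:], dim=0).to(dtype=torch.float16, device="cpu")
print(prompt_embeds.shape)  # torch.Size([3, 2, 8, 16]) -> [layers, batch, seq_len, dim]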
@@ -1060,6 +1060,8 @@ def encode_prompt(
         attention_mask=attention_mask_list[1] if attention_mask_list else None,
     )

+    print("t5_prompt_embeds", t5_prompt_embeds.shape)
+    print("llama3_prompt_embeds", llama3_prompt_embeds.shape)
     prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]

     return prompt_embeds, pooled_prompt_embeds
@@ -1431,7 +1433,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             prompt_embeds, pooled_prompt_embeds = encode_prompt(
                 text_encoders, tokenizers, prompt, args.max_sequence_length
             )
-            prompt_embeds = prompt_embeds.to(accelerator.device)
+            prompt_embeds[0] = prompt_embeds[0].to(accelerator.device)
+            prompt_embeds[1] = prompt_embeds[1].to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds

@@ -1587,7 +1590,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if train_dataset.custom_instance_prompts:
                     prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
                 else:
-                    prompt_embeds = prompt_embeds.repeat(len(prompts), 1, 1)
+                    prompt_embeds[0] = prompt_embeds[0].repeat(len(prompts), 1, 1)
+                    prompt_embeds[1] = prompt_embeds[1].repeat(1, len(prompts), 1, 1)
                     pooled_prompt_embeds = pooled_prompt_embeds.repeat(len(prompts), 1)
                 # Convert images to latent space
                 if args.cache_latents:
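The later hunks follow from `encode_prompt` now returning a two-element list: the T5 embedding is a 3-D `[batch, seq_len, dim]` tensor, while the stacked Llama3 hidden states are 4-D `[layers, batch, seq_len, dim]` (the `_, _, seq_len, dim = prompt_embeds.shape` unpacking above), so the per-prompt repeat runs along dim 0 for the former and dim 1 for the latter. A rough sketch of that repeat logic, with small illustrative shapes and names:

import torch

num_prompts = 4
t5_prompt_embeds = torch.randn(1, 8, 16)         # [batch, seq_len, dim]
llama3_prompt_embeds = torch.randn(3, 1, 8, 16)  # [layers, batch, seq_len, dim]

prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
prompt_embeds[0] = prompt_embeds[0].repeat(num_prompts, 1, 1)     # batch axis is dim 0
prompt_embeds[1] = prompt_embeds[1].repeat(1, num_prompts, 1, 1)  # batch axis is dim 1
print(prompt_embeds[0].shape, prompt_embeds[1].shape)
# torch.Size([4, 8, 16]) torch.Size([3, 4, 8, 16])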