@@ -780,13 +780,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
780
780
text_input_ids = text_input_ids_list [i ]
781
781
782
782
prompt_embeds = text_encoder (
783
- text_input_ids .to (text_encoder .device ),
784
- output_hidden_states = True ,
783
+ text_input_ids .to (text_encoder .device ), output_hidden_states = True , return_dict = False
785
784
)
786
785
787
786
# We are always only interested in the pooled output of the final text encoder
788
787
pooled_prompt_embeds = prompt_embeds [0 ]
789
- prompt_embeds = prompt_embeds . hidden_states [- 2 ]
788
+ prompt_embeds = prompt_embeds [ - 1 ] [- 2 ]
790
789
bs_embed , seq_len , _ = prompt_embeds .shape
791
790
prompt_embeds = prompt_embeds .view (bs_embed , seq_len , - 1 )
792
791
prompt_embeds_list .append (prompt_embeds )
@@ -1429,7 +1428,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1429
1428
timesteps ,
1430
1429
prompt_embeds_input ,
1431
1430
added_cond_kwargs = unet_added_conditions ,
1432
- ).sample
1431
+ return_dict = False ,
1432
+ )[0 ]
1433
1433
else :
1434
1434
unet_added_conditions = {"time_ids" : add_time_ids .repeat (elems_to_repeat_time_ids , 1 )}
1435
1435
prompt_embeds , pooled_prompt_embeds = encode_prompt (
@@ -1443,8 +1443,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1443
1443
)
1444
1444
prompt_embeds_input = prompt_embeds .repeat (elems_to_repeat_text_embeds , 1 , 1 )
1445
1445
model_pred = unet (
1446
- noisy_model_input , timesteps , prompt_embeds_input , added_cond_kwargs = unet_added_conditions
1447
- ).sample
1446
+ noisy_model_input ,
1447
+ timesteps ,
1448
+ prompt_embeds_input ,
1449
+ added_cond_kwargs = unet_added_conditions ,
1450
+ return_dict = False ,
1451
+ )[0 ]
1448
1452
1449
1453
# Get the target for loss depending on the prediction type
1450
1454
if noise_scheduler .config .prediction_type == "epsilon" :
0 commit comments