File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed
src/diffusers/pipelines/pixart_alpha Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -156,6 +156,8 @@ def encode_prompt(
156
156
mask_feature: (bool, defaults to `True`):
157
157
If `True`, the function will mask the text embeddings.
158
158
"""
159
+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
160
+
159
161
if device is None :
160
162
device = self ._execution_device
161
163
@@ -253,7 +255,7 @@ def encode_prompt(
253
255
negative_prompt_embeds = None
254
256
255
257
# Perform additional masking.
256
- if mask_feature and prompt_embeds is None and negative_prompt_embeds is None :
258
+ if mask_feature and not embeds_initially_provided :
257
259
prompt_embeds = prompt_embeds .unsqueeze (1 )
258
260
masked_prompt_embeds , keep_indices = self .mask_text_embeddings (prompt_embeds , prompt_embeds_attention_mask )
259
261
masked_prompt_embeds = masked_prompt_embeds .squeeze (1 )
Original file line number Diff line number Diff line change @@ -120,7 +120,6 @@ def test_save_load_optional_components(self):
120
120
"generator" : generator ,
121
121
"num_inference_steps" : num_inference_steps ,
122
122
"output_type" : output_type ,
123
- "mask_feature" : False ,
124
123
}
125
124
126
125
# set all optional components to None
@@ -155,7 +154,6 @@ def test_save_load_optional_components(self):
155
154
"generator" : generator ,
156
155
"num_inference_steps" : num_inference_steps ,
157
156
"output_type" : output_type ,
158
- "mask_feature" : False ,
159
157
}
160
158
161
159
output_loaded = pipe_loaded (** inputs )[0 ]
You can’t perform that action at this time.
0 commit comments