Skip to content

Commit 4d8a908

Browse files
committed
[PixArt-Alpha] fix mask feature condition. (#5695)
* fix mask feature condition. * debug * remove identical test * set correct * Empty-Commit
1 parent 96829f0 commit 4d8a908

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ def encode_prompt(
156156
mask_feature: (bool, defaults to `True`):
157157
If `True`, the function will mask the text embeddings.
158158
"""
159+
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
160+
159161
if device is None:
160162
device = self._execution_device
161163

@@ -253,7 +255,7 @@ def encode_prompt(
253255
negative_prompt_embeds = None
254256

255257
# Perform additional masking.
256-
if mask_feature and prompt_embeds is None and negative_prompt_embeds is None:
258+
if mask_feature and not embeds_initially_provided:
257259
prompt_embeds = prompt_embeds.unsqueeze(1)
258260
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
259261
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)

tests/pipelines/pixart/test_pixart.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def test_save_load_optional_components(self):
120120
"generator": generator,
121121
"num_inference_steps": num_inference_steps,
122122
"output_type": output_type,
123-
"mask_feature": False,
124123
}
125124

126125
# set all optional components to None
@@ -155,7 +154,6 @@ def test_save_load_optional_components(self):
155154
"generator": generator,
156155
"num_inference_steps": num_inference_steps,
157156
"output_type": output_type,
158-
"mask_feature": False,
159157
}
160158

161159
output_loaded = pipe_loaded(**inputs)[0]

0 commit comments

Comments
 (0)