
Commit 9c74ac3

encode_single_prompt
1 parent b8fe8d4 commit 9c74ac3

File tree

1 file changed: +55 -17 lines changed


src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 55 additions & 17 deletions
@@ -342,33 +342,24 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         )
         (
             data.prompt_embeds,
+            data.negative_prompt_embeds,
             data.pooled_prompt_embeds,
+            data.negative_pooled_prompt_embeds,
         ) = pipeline.encode_prompt(
             data.prompt,
             data.prompt_2,
             data.device,
+            data.do_classifier_free_guidance,
+            data.negative_prompt,
+            data.negative_prompt_2,
             prompt_embeds=None,
+            negative_prompt_embeds=None,
             pooled_prompt_embeds=None,
+            negative_pooled_prompt_embeds=None,
             lora_scale=data.text_encoder_lora_scale,
             clip_skip=data.clip_skip,
+            force_zeros_for_empty_prompt=self.configs.get('force_zeros_for_empty_prompt', False),
         )
-        zero_out_negative_prompt = data.negative_prompt is None and self.configs.get('force_zeros_for_empty_prompt', False)
-        if data.do_classifier_free_guidance and zero_out_negative_prompt:
-            data.negative_prompt_embeds = torch.zeros_like(data.prompt_embeds)
-            data.negative_pooled_prompt_embeds = torch.zeros_like(data.pooled_prompt_embeds)
-        elif data.do_classifier_free_guidance and not zero_out_negative_prompt:
-            (
-                data.negative_prompt_embeds,
-                data.negative_pooled_prompt_embeds,
-            ) = pipeline.encode_prompt(
-                data.negative_prompt,
-                data.negative_prompt_2,
-                data.device,
-                prompt_embeds=None,
-                pooled_prompt_embeds=None,
-                lora_scale=data.text_encoder_lora_scale,
-                clip_skip=data.clip_skip,
-            )
         # Add outputs
         self.add_block_state(state, data)
         return pipeline, state
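
What moved in this hunk: the zero-out-versus-encode branching that previously lived in the block (the removed lines above) is now handled inside pipeline.encode_prompt, so the block simply passes do_classifier_free_guidance, the negative prompts, and force_zeros_for_empty_prompt straight through. A minimal sketch of that decision in isolation, written as a free function purely for illustration (the real code is the method added in the next hunk; encode_negative stands in for the pipeline's single-prompt encoder):

import torch

# Illustrative stand-alone version of the branching removed from the block above.
def negative_embeds_for_cfg(prompt_embeds, pooled_prompt_embeds, negative_prompt,
                            do_classifier_free_guidance, force_zeros_for_empty_prompt,
                            encode_negative):
    if not do_classifier_free_guidance:
        # No CFG: negative embeddings are never used, leave them unset.
        return None, None
    if negative_prompt is None and force_zeros_for_empty_prompt:
        # Empty negative prompt with force_zeros: all-zero tensors of matching shape.
        return torch.zeros_like(prompt_embeds), torch.zeros_like(pooled_prompt_embeds)
    # Otherwise the negative prompt is encoded exactly like the positive one.
    return encode_negative(negative_prompt)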
@@ -3262,6 +3253,53 @@ def prepare_control_image(
         return image
 
     def encode_prompt(
+        self,
+        prompt: str,
+        prompt_2: Optional[str] = None,
+        device: Optional[torch.device] = None,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+        force_zeros_for_empty_prompt: bool = False,
+    ):
+        (
+            prompt_embeds,
+            pooled_prompt_embeds,
+        ) = self.encode_single_prompt(
+            prompt,
+            prompt_2,
+            device,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            lora_scale=lora_scale,
+            clip_skip=clip_skip,
+        )
+        zero_out_negative_prompt = negative_prompt is None and force_zeros_for_empty_prompt
+        if do_classifier_free_guidance and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+        elif do_classifier_free_guidance and not zero_out_negative_prompt:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            ) = self.encode_single_prompt(
+                negative_prompt,
+                negative_prompt_2,
+                device,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                lora_scale=lora_scale,
+                clip_skip=clip_skip,
+            )
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    def encode_single_prompt(
         self,
         prompt: str,
         prompt_2: Optional[str] = None,
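
Net effect of the second hunk: one encode_prompt call now returns all four embedding tensors, while the previous per-prompt path survives unchanged under the new name encode_single_prompt. A self-contained sketch of that return contract, with a fake encoder and made-up tensor shapes standing in for the real CLIP text encoders (nothing below is the pipeline's actual code):

import torch

# Fake single-prompt encoder; shapes are illustrative only.
def fake_encode_single_prompt(prompt):
    torch.manual_seed(0)
    return torch.randn(1, 77, 2048), torch.randn(1, 1280)

def encode_prompt_sketch(prompt, negative_prompt=None,
                         do_classifier_free_guidance=True,
                         force_zeros_for_empty_prompt=False):
    # Positive prompt is always encoded.
    prompt_embeds, pooled_prompt_embeds = fake_encode_single_prompt(prompt)
    negative_prompt_embeds = negative_pooled_prompt_embeds = None
    zero_out = negative_prompt is None and force_zeros_for_empty_prompt
    if do_classifier_free_guidance and zero_out:
        # Mirror the zero-out branch of the new encode_prompt.
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance:
        negative_prompt_embeds, negative_pooled_prompt_embeds = fake_encode_single_prompt(negative_prompt or "")
    # Same four-tuple order as the method added in this commit.
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

# force_zeros_for_empty_prompt with no negative prompt yields all-zero negative embeddings.
_, neg, _, neg_pooled = encode_prompt_sketch("an astronaut", force_zeros_for_empty_prompt=True)
assert torch.equal(neg, torch.zeros_like(neg))
assert torch.equal(neg_pooled, torch.zeros_like(neg_pooled))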
