
Commit b8fe8d4

[modular] Refactor pipeline functions

1 parent 12650e1 commit b8fe8d4


src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 43 additions & 162 deletions
@@ -221,15 +221,24 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
             ip_adapter_image=data.ip_adapter_image,
             ip_adapter_image_embeds=None,
             device=data.device,
-            num_images_per_prompt=1,
-            do_classifier_free_guidance=data.do_classifier_free_guidance,
         )
+
         if data.do_classifier_free_guidance:
-            data.negative_ip_adapter_embeds = []
-            for i, image_embeds in enumerate(data.ip_adapter_embeds):
-                negative_image_embeds, image_embeds = image_embeds.chunk(2)
-                data.negative_ip_adapter_embeds.append(negative_image_embeds)
-                data.ip_adapter_embeds[i] = image_embeds
+            output_hidden_states = [not isinstance(image_proj_layer, ImageProjection) for image_proj_layer in pipeline.unet.encoder_hid_proj.image_projection_layers]
+            negative_ip_adapter_embeds = []
+            for (idx, output_hidden_state), ip_adapter_embeds in zip(enumerate(output_hidden_states), data.ip_adapter_embeds):
+                if not output_hidden_state:
+                    negative_ip_adapter_embed = torch.zeros_like(ip_adapter_embeds)
+                else:
+                    ip_adapter_image = data.ip_adapter_image[idx] if isinstance(data.ip_adapter_image, list) else data.ip_adapter_image
+                    ip_adapter_image = pipeline.feature_extractor(ip_adapter_image, return_tensors="pt").pixel_values
+                    negative_ip_adapter_embed = pipeline.prepare_ip_adapter_image_embeds(
+                        ip_adapter_image=ip_adapter_image,
+                        ip_adapter_image_embeds=None,
+                        device=data.device,
+                    )
+                negative_ip_adapter_embeds.append(negative_ip_adapter_embed)
+            data.negative_ip_adapter_embeds = negative_ip_adapter_embeds
 
         self.add_block_state(state, data)
         return pipeline, state
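
The new guidance branch builds negative IP-adapter embeds per projection layer: zeros for plain ImageProjection layers, a re-encoded image for hidden-state layers. For reference, the selection pattern boils down to the following minimal runnable sketch, with stand-in classes and random tensors in place of the pipeline's real layers and embeds:

import torch

class ImageProjection:
    """Stand-in for diffusers' ImageProjection layer type."""

class PerceiverProjection:
    """Stand-in for any non-ImageProjection (hidden-state based) layer."""

layers = [ImageProjection(), PerceiverProjection()]
ip_adapter_embeds = [torch.randn(1, 4), torch.randn(1, 16, 4)]

# one flag per projection layer, as in the hunk above
output_hidden_states = [not isinstance(layer, ImageProjection) for layer in layers]

negative_ip_adapter_embeds = []
for output_hidden_state, embeds in zip(output_hidden_states, ip_adapter_embeds):
    if not output_hidden_state:
        # plain ImageProjection: an all-zero embedding serves as the negative
        negative = torch.zeros_like(embeds)
    else:
        # hidden-state layers instead re-encode an image via the feature
        # extractor and prepare_ip_adapter_image_embeds; zeros stand in here
        negative = torch.zeros_like(embeds)
    negative_ip_adapter_embeds.append(negative)
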
@@ -333,24 +342,33 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         )
         (
             data.prompt_embeds,
-            data.negative_prompt_embeds,
             data.pooled_prompt_embeds,
-            data.negative_pooled_prompt_embeds,
         ) = pipeline.encode_prompt(
             data.prompt,
             data.prompt_2,
             data.device,
-            1,
-            data.do_classifier_free_guidance,
-            data.negative_prompt,
-            data.negative_prompt_2,
             prompt_embeds=None,
-            negative_prompt_embeds=None,
             pooled_prompt_embeds=None,
-            negative_pooled_prompt_embeds=None,
             lora_scale=data.text_encoder_lora_scale,
             clip_skip=data.clip_skip,
         )
+        zero_out_negative_prompt = data.negative_prompt is None and self.configs.get('force_zeros_for_empty_prompt', False)
+        if data.do_classifier_free_guidance and zero_out_negative_prompt:
+            data.negative_prompt_embeds = torch.zeros_like(data.prompt_embeds)
+            data.negative_pooled_prompt_embeds = torch.zeros_like(data.pooled_prompt_embeds)
+        elif data.do_classifier_free_guidance and not zero_out_negative_prompt:
+            (
+                data.negative_prompt_embeds,
+                data.negative_pooled_prompt_embeds,
+            ) = pipeline.encode_prompt(
+                data.negative_prompt,
+                data.negative_prompt_2,
+                data.device,
+                prompt_embeds=None,
+                pooled_prompt_embeds=None,
+                lora_scale=data.text_encoder_lora_scale,
+                clip_skip=data.clip_skip,
+            )
         # Add outputs
         self.add_block_state(state, data)
         return pipeline, state
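
With negative-prompt handling pulled out of encode_prompt, the block encodes the positive prompt first and then either zeros out or separately encodes the negatives. A rough sketch of that control flow, where encode is a hypothetical stand-in for pipeline.encode_prompt:

import torch

def encode(prompt):
    # hypothetical stand-in for pipeline.encode_prompt: (embeds, pooled_embeds)
    return torch.randn(1, 77, 2048), torch.randn(1, 1280)

do_classifier_free_guidance = True
negative_prompt = None
force_zeros_for_empty_prompt = True  # SDXL's config flag

prompt_embeds, pooled_prompt_embeds = encode("a photo of a cat")

zero_out_negative_prompt = negative_prompt is None and force_zeros_for_empty_prompt
if do_classifier_free_guidance and zero_out_negative_prompt:
    # empty negative prompt + force_zeros: all-zero negative embeddings
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)
    negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance:
    # otherwise the negative prompt runs through the same encoder a second time
    negative_prompt_embeds, negative_pooled_prompt_embeds = encode(negative_prompt or "")
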
@@ -3197,8 +3215,7 @@ def _get_add_time_ids_img2img(
 
         return add_time_ids, add_neg_time_ids
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+    def encode_image(self, image, device, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype
 
         if not isinstance(image, torch.Tensor):
@@ -3207,20 +3224,10 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
         image = image.to(device=device, dtype=dtype)
         if output_hidden_states:
             image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_enc_hidden_states = self.image_encoder(
-                torch.zeros_like(image), output_hidden_states=True
-            ).hidden_states[-2]
-            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
-                num_images_per_prompt, dim=0
-            )
-            return image_enc_hidden_states, uncond_image_enc_hidden_states
+            return image_enc_hidden_states
         else:
             image_embeds = self.image_encoder(image).image_embeds
-            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_embeds = torch.zeros_like(image_embeds)
-
-            return image_embeds, uncond_image_embeds
+            return image_embeds
 
     # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
     # 1. return image without apply any guidance
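
encode_image now returns only the conditional embedding; the unconditional half the old version produced is left to the caller (for the non-hidden-state path it was simply zeros). A small sketch of the new contract, with a dummy encoder standing in for self.image_encoder:

import torch

def encode_image(image, output_hidden_states=False):
    # dummy stand-in for the slimmed-down method: one tensor out, no uncond pair
    return torch.randn(image.shape[0], 1024)

image = torch.rand(1, 3, 224, 224)
image_embeds = encode_image(image)

# the uncond embedding the old method returned is now built at the call site
uncond_image_embeds = torch.zeros_like(image_embeds)
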
@@ -3254,20 +3261,13 @@ def prepare_control_image(
 
         return image
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: str,
         prompt_2: Optional[str] = None,
         device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: Optional[str] = None,
-        negative_prompt_2: Optional[str] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -3282,31 +3282,12 @@ def encode_prompt(
                 used in both text-encoders
             device: (`torch.device`):
                 torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            do_classifier_free_guidance (`bool`):
-                whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
-                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
             pooled_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
-                input argument.
             lora_scale (`float`, *optional*):
                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
             clip_skip (`int`, *optional*):
@@ -3391,92 +3372,11 @@ def encode_prompt(
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
 
-        # get unconditional embeddings for classifier free guidance
-        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
-        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
-            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
-            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
-        elif do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt_2 or negative_prompt
-
-            # normalize str to list
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_2 = (
-                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
-            )
-
-            uncond_tokens: List[str]
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = [negative_prompt, negative_prompt_2]
-
-            negative_prompt_embeds_list = []
-            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
-                if isinstance(self, TextualInversionLoaderMixin):
-                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
-
-                max_length = prompt_embeds.shape[1]
-                uncond_input = tokenizer(
-                    negative_prompt,
-                    padding="max_length",
-                    max_length=max_length,
-                    truncation=True,
-                    return_tensors="pt",
-                )
-
-                negative_prompt_embeds = text_encoder(
-                    uncond_input.input_ids.to(device),
-                    output_hidden_states=True,
-                )
-                # We are only ALWAYS interested in the pooled output of the final text encoder
-                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
-                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
-
-                negative_prompt_embeds_list.append(negative_prompt_embeds)
-
-            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
-
         if self.text_encoder_2 is not None:
             prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
         else:
             prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
 
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            if self.text_encoder_2 is not None:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
-            else:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
-        if do_classifier_free_guidance:
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-                bs_embed * num_images_per_prompt, -1
-            )
-
         if self.text_encoder is not None:
             if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
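
After this hunk, encode_prompt no longer tiles embeddings for num_images_per_prompt and never touches negatives. A caller that still wants the old batching could reproduce the removed repeat/view pattern itself, roughly like this (random tensors, not code from this commit):

import torch

num_images_per_prompt = 2
prompt_embeds = torch.randn(1, 77, 2048)    # as returned by encode_prompt
pooled_prompt_embeds = torch.randn(1, 1280)

bs_embed, seq_len, _ = prompt_embeds.shape
# the mps-friendly repeat/view pattern the removed code used
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
    bs_embed * num_images_per_prompt, -1
)
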
@@ -3487,16 +3387,13 @@ def encode_prompt(
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder_2, lora_scale)
 
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+        return prompt_embeds, pooled_prompt_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+        self, ip_adapter_image, ip_adapter_image_embeds, device
     ):
-        image_embeds = []
-        if do_classifier_free_guidance:
-            negative_image_embeds = []
         if ip_adapter_image_embeds is None:
+            ip_adapter_image_embeds = []
            if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]
 
@@ -3509,29 +3406,13 @@ def prepare_ip_adapter_image_embeds(
                 ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
             ):
                 output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-                single_image_embeds, single_negative_image_embeds = self.encode_image(
-                    single_ip_adapter_image, device, 1, output_hidden_state
+                single_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, output_hidden_state
                 )
+                ip_adapter_image_embeds.append(single_image_embeds[None, :])
 
-                image_embeds.append(single_image_embeds[None, :])
-                if do_classifier_free_guidance:
-                    negative_image_embeds.append(single_negative_image_embeds[None, :])
-        else:
-            for single_image_embeds in ip_adapter_image_embeds:
-                if do_classifier_free_guidance:
-                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    negative_image_embeds.append(single_negative_image_embeds)
-                image_embeds.append(single_image_embeds)
-
-        ip_adapter_image_embeds = []
-        for i, single_image_embeds in enumerate(image_embeds):
-            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-            if do_classifier_free_guidance:
-                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
-
+        for single_image_embeds in ip_adapter_image_embeds:
             single_image_embeds = single_image_embeds.to(device=device)
-            ip_adapter_image_embeds.append(single_image_embeds)
 
         return ip_adapter_image_embeds
 
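
prepare_ip_adapter_image_embeds likewise returns only positive embeds: it no longer splits precomputed [negative, positive] tensors or concatenates them for guidance. Callers that relied on that packing would handle it themselves, along these lines (hypothetical tensors, not code from this commit):

import torch

# precomputed embeds packed as [negative, positive] along dim 0, as the old API expected
packed = torch.randn(2, 1, 16, 2048)
negative_image_embeds, image_embeds = packed.chunk(2)  # what the old method did internally

# under classifier-free guidance, the caller re-concatenates at denoise time
model_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
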