
Commit 02edb0f

separate encode_prompt, add copied from, image_encoder offload
1 parent 248bbd4 commit 02edb0f

11 files changed: 30 additions & 65 deletions

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 20 additions & 65 deletions
@@ -177,7 +177,7 @@ class FluxPipeline(
         [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
     _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
 
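With image_encoder added to model_cpu_offload_seq, enable_model_cpu_offload() now cycles the optional image encoder through GPU memory like every other component instead of leaving it out of the sequence. A minimal sketch of the user-facing call, assuming a CUDA machine (the checkpoint id is illustrative):

import torch

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Components listed in model_cpu_offload_seq, now including image_encoder when
# one is loaded, are moved to the GPU only while they run and offloaded after.
pipe.enable_model_cpu_offload()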
@@ -314,17 +314,12 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
-        negative_prompt: Union[str, List[str]] = None,
-        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
-        do_true_cfg: bool = False,
     ):
         r"""
 
@@ -361,62 +356,24 @@ def encode_prompt(
             scale_lora_layers(self.text_encoder_2, lora_scale)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if do_true_cfg and negative_prompt is not None:
-            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_batch_size = len(negative_prompt)
-
-            if negative_batch_size != batch_size:
-                raise ValueError(
-                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
-                )
-
-            # Concatenate prompts
-            prompts = prompt + negative_prompt
-            prompts_2 = (
-                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
-            )
-        else:
-            prompts = prompt
-            prompts_2 = prompt_2
 
         if prompt_embeds is None:
-            if prompts_2 is None:
-                prompts_2 = prompts
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
             # We only use the pooled prompt output from the CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompts,
+                prompt=prompt,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompts_2,
+                prompt=prompt_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
 
-        if do_true_cfg and negative_prompt is not None:
-            # Split embeddings back into positive and negative parts
-            total_batch_size = batch_size * num_images_per_prompt
-            positive_indices = slice(0, total_batch_size)
-            negative_indices = slice(total_batch_size, 2 * total_batch_size)
-
-            positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
-            negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
-
-            positive_prompt_embeds = prompt_embeds[positive_indices]
-            negative_prompt_embeds = prompt_embeds[negative_indices]
-
-            pooled_prompt_embeds = positive_pooled_prompt_embeds
-            prompt_embeds = positive_prompt_embeds
-
-        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers

@@ -430,16 +387,7 @@ def encode_prompt(
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
-        if do_true_cfg and negative_prompt is not None:
-            return (
-                prompt_embeds,
-                pooled_prompt_embeds,
-                text_ids,
-                negative_prompt_embeds,
-                negative_pooled_prompt_embeds,
-            )
-        else:
-            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None
+        return prompt_embeds, pooled_prompt_embeds, text_ids
 
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype

@@ -832,22 +780,29 @@ def __call__(
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
-            negative_prompt_embeds,
-            negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
-            negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
-            do_true_cfg=do_true_cfg,
+        )
+        (
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            _,
+        ) = self.encode_prompt(
+            prompt=negative_prompt,
+            prompt_2=negative_prompt_2,
+            prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
         )
 
         # 4. Prepare latent variables
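Since encode_prompt no longer takes negative prompts or a do_true_cfg flag, callers that need true CFG now encode the negative prompt with a second call, as __call__ does above. A minimal sketch against an already-loaded pipe (prompts and max_sequence_length are illustrative):

# encode_prompt now returns a 3-tuple; the old 5-tuple with
# (negative_prompt_embeds, negative_pooled_prompt_embeds) slots is gone.
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a corgi astronaut",
    prompt_2=None,  # falls back to prompt via `prompt_2 = prompt_2 or prompt`
    max_sequence_length=512,
)
negative_prompt_embeds, negative_pooled_prompt_embeds, _ = pipe.encode_prompt(
    prompt="blurry, low quality",
    prompt_2=None,
    max_sequence_length=512,
)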

src/diffusers/pipelines/flux/pipeline_flux_control.py

Lines changed: 1 addition & 0 deletions
@@ -323,6 +323,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
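The "# Copied from" marker added here (and in each pipeline below) is the diffusers convention for mechanically synchronized code: utils/check_copies.py fails CI when a marked copy drifts from its source, and make fix-copies regenerates it, which is why one refactor in FluxPipeline.encode_prompt fans out across all these files. A quick sanity check, assuming both classes are exported by the installed diffusers:

import inspect

from diffusers import FluxControlPipeline, FluxPipeline

# After fix-copies the marked method body is identical to the original,
# so the two sources should compare equal.
original = inspect.getsource(FluxPipeline.encode_prompt)
copy = inspect.getsource(FluxControlPipeline.encode_prompt)
print(copy == original)  # expected: True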

src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py

Lines changed: 1 addition & 0 deletions
@@ -333,6 +333,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py

Lines changed: 1 addition & 0 deletions
@@ -371,6 +371,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 1 addition & 0 deletions
@@ -333,6 +333,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 1 addition & 0 deletions
@@ -333,6 +333,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 1 addition & 0 deletions
@@ -343,6 +343,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/flux/pipeline_flux_fill.py

Lines changed: 1 addition & 0 deletions
@@ -414,6 +414,7 @@ def prepare_mask_latents(
 
         return mask, masked_image_latents
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 1 addition & 0 deletions
@@ -317,6 +317,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],

src/diffusers/pipelines/flux/pipeline_flux_inpaint.py

Lines changed: 1 addition & 0 deletions
@@ -321,6 +321,7 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
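End to end, the user-facing API is unchanged: true classifier-free guidance is still requested through negative_prompt and true_cfg_scale on __call__; only the internal encoding now happens in two passes. A minimal usage sketch (checkpoint id, prompts, and the true_cfg_scale value are illustrative):

import torch

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="a corgi astronaut floating above Earth",
    negative_prompt="blurry, low quality, watermark",
    true_cfg_scale=4.0,  # > 1.0 together with a negative prompt enables the do_true_cfg path
    num_inference_steps=28,
).images[0]
image.save("corgi.png")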
