
Commit 5b3295a

apply to flux

1 parent 78f292e commit 5b3295a

13 files changed: +34 -3261 lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 2 additions & 287 deletions
@@ -31,16 +31,13 @@
 from ...models import AutoencoderKL, FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
-    USE_PEFT_BACKEND,
-    deprecate,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
+from .pipeline_flux_utils import FluxMixin
 from .pipeline_output import FluxPipelineOutput
 
 
@@ -146,6 +143,7 @@ def retrieve_timesteps(
 
 class FluxPipeline(
     DiffusionPipeline,
+    FluxMixin,
     FluxLoraLoaderMixin,
     FromSingleFileMixin,
     TextualInversionLoaderMixin,
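
The diff imports `FluxMixin` from the new `pipeline_flux_utils` module and adds it to the pipeline's bases; the helpers deleted in the hunks below are presumably what that mixin now provides. The mixin's file is not part of this excerpt, so the following is only a sketch of the shape it would need, assuming the removed members were moved over unchanged (method bodies elided):

# Hypothetical sketch of pipeline_flux_utils.FluxMixin -- the actual module is not
# shown in this commit excerpt; names are taken from the members removed below.
class FluxMixin:
    """Helpers shared across the Flux pipelines."""

    def _get_t5_prompt_embeds(self, prompt=None, num_images_per_prompt=1,
                              max_sequence_length=512, device=None, dtype=None): ...

    def _get_clip_prompt_embeds(self, prompt, num_images_per_prompt=1, device=None): ...

    def encode_prompt(self, prompt, prompt_2=None, device=None, num_images_per_prompt=1,
                      prompt_embeds=None, pooled_prompt_embeds=None,
                      max_sequence_length=512, lora_scale=None): ...

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype): ...

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width): ...

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor): ...

    # Read-only denoising state exposed to callers and callbacks.
    @property
    def guidance_scale(self): return self._guidance_scale

    @property
    def joint_attention_kwargs(self): return self._joint_attention_kwargs

    @property
    def num_timesteps(self): return self._num_timesteps

    @property
    def current_timestep(self): return self._current_timestep

    @property
    def interrupt(self): return self._interrupt

Because `FluxPipeline` previously defined these members itself and now only inherits them, the call sites in `__call__` are unaffected as long as the mixin keeps the same names and signatures.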
@@ -215,178 +213,6 @@ def __init__(
         )
         self.default_sample_size = 128
 
-    def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 512,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-
-        if isinstance(self, TextualInversionLoaderMixin):
-            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
-
-        text_inputs = self.tokenizer_2(
-            prompt,
-            padding="max_length",
-            max_length=max_sequence_length,
-            truncation=True,
-            return_length=False,
-            return_overflowing_tokens=False,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because `max_sequence_length` is set to "
-                f" {max_sequence_length} tokens: {removed_text}"
-            )
-
-        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
-
-        dtype = self.text_encoder_2.dtype
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        _, seq_len, _ = prompt_embeds.shape
-
-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        return prompt_embeds
-
-    def _get_clip_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-    ):
-        device = device or self._execution_device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-
-        if isinstance(self, TextualInversionLoaderMixin):
-            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer_max_length,
-            truncation=True,
-            return_overflowing_tokens=False,
-            return_length=False,
-            return_tensors="pt",
-        )
-
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer_max_length} tokens: {removed_text}"
-            )
-        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
-
-        # Use pooled output of CLIPTextModel
-        prompt_embeds = prompt_embeds.pooler_output
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
-        return prompt_embeds
-
-    def encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        max_sequence_length: int = 512,
-        lora_scale: Optional[float] = None,
-    ):
-        r"""
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                used in all text-encoders
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
-        """
-        device = device or self._execution_device
-
-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-            # dynamically adjust the LoRA scale
-            if self.text_encoder is not None and USE_PEFT_BACKEND:
-                scale_lora_layers(self.text_encoder, lora_scale)
-            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
-                scale_lora_layers(self.text_encoder_2, lora_scale)
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-
-        if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
-
-            # We only use the pooled prompt output from the CLIPTextModel
-            pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-            )
-            prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-            )
-
-        if self.text_encoder is not None:
-            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(self.text_encoder, lora_scale)
-
-        if self.text_encoder_2 is not None:
-            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
-                # Retrieve the original scale by scaling back the LoRA layers
-                unscale_lora_layers(self.text_encoder_2, lora_scale)
-
-        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-
-        return prompt_embeds, pooled_prompt_embeds, text_ids
-
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
 
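With the encoders' helpers gone from this file, `pipe.encode_prompt(...)` now resolves through `FluxMixin` rather than `FluxPipeline` itself. A minimal call-site sketch, assuming the mixin keeps the signature removed above (model id and prompt are placeholders):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Same call as before the refactor; only the defining class changed.
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a misty forest at dawn",
    prompt_2=None,              # falls back to `prompt` for the T5 encoder
    num_images_per_prompt=1,
    max_sequence_length=512,
)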

@@ -503,97 +329,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
-    @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
-    @staticmethod
-    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-        latents = latents.permute(0, 2, 4, 1, 3, 5)
-        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-        return latents
-
-    @staticmethod
-    def _unpack_latents(latents, height, width, vae_scale_factor):
-        batch_size, num_patches, channels = latents.shape
-
-        # VAE applies 8x compression on images but we must also account for packing which requires
-        # latent height and width to be divisible by 2.
-        height = 2 * (int(height) // (vae_scale_factor * 2))
-        width = 2 * (int(width) // (vae_scale_factor * 2))
-
-        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
-        latents = latents.permute(0, 3, 1, 4, 2, 5)
-
-        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
-
-        return latents
-
-    def enable_vae_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
-        deprecate(
-            "enable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.enable_slicing()
-
-    def disable_vae_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
-        deprecate(
-            "disable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.disable_slicing()
-
-    def enable_vae_tiling(self):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
-        deprecate(
-            "enable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.enable_tiling()
-
-    def disable_vae_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
-        deprecate(
-            "disable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.disable_tiling()
-
     def prepare_latents(
         self,
         batch_size,
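
`_pack_latents` turns a `(B, C, H, W)` latent into a sequence of 2x2 patches (one token of `4*C` channels per patch) and `_unpack_latents` inverts it; both now come from the mixin. A standalone round-trip sketch of the same math, using hypothetical free functions copied from the removed code:

import torch

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4): each 2x2 spatial patch becomes one token
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

def unpack_latents(latents, height, width, vae_scale_factor):
    # Inverse of pack_latents; height/width are given in pixel space.
    batch_size, num_patches, channels = latents.shape
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))
    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels // (2 * 2), height, width)

x = torch.randn(1, 16, 128, 128)            # latent for a 1024x1024 image (vae_scale_factor=8)
packed = pack_latents(x, 1, 16, 128, 128)   # -> (1, 4096, 64)
restored = unpack_latents(packed, 1024, 1024, vae_scale_factor=8)
assert torch.equal(x, restored)

The deprecated `enable_vae_slicing()`/`enable_vae_tiling()` wrappers removed in this hunk already pointed users at `pipe.vae.enable_slicing()` and `pipe.vae.enable_tiling()`, which remain the supported spellings.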
@@ -629,26 +364,6 @@ def prepare_latents(
 
         return latents, latent_image_ids
 
-    @property
-    def guidance_scale(self):
-        return self._guidance_scale
-
-    @property
-    def joint_attention_kwargs(self):
-        return self._joint_attention_kwargs
-
-    @property
-    def num_timesteps(self):
-        return self._num_timesteps
-
-    @property
-    def current_timestep(self):
-        return self._current_timestep
-
-    @property
-    def interrupt(self):
-        return self._interrupt
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
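
The denoising-state properties (`guidance_scale`, `joint_attention_kwargs`, `num_timesteps`, `current_timestep`, `interrupt`) also move to the mixin, so callback code that reads them should keep working. A small sketch reusing the `pipe` from the earlier example, assuming the usual `callback_on_step_end` hook:

def log_step(pipe, step_index, timestep, callback_kwargs):
    # These attributes are now inherited from FluxMixin rather than defined on FluxPipeline.
    print(f"step {step_index + 1}/{pipe.num_timesteps}  t={timestep}  cfg={pipe.guidance_scale}")
    return callback_kwargs

image = pipe(
    "a misty forest at dawn",
    num_inference_steps=28,
    guidance_scale=3.5,
    callback_on_step_end=log_step,
).images[0]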
