 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging
+from ...utils import logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import AuraFlowPipeline
+
+        >>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
+        >>> pipe = pipe.to("cuda")
+        >>> prompt = "A cat holding a sign that says hello world"
+        >>> image = pipe(prompt).images[0]
+        >>> image.save("aura_flow.png")
+        ```
+"""
+
+
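For reviewers unfamiliar with `replace_example_docstring`: it splices `EXAMPLE_DOC_STRING` into the decorated function's docstring at the bare `Examples:` placeholder, so `help(AuraFlowPipeline.__call__)` and the rendered docs show the runnable snippet above. A minimal sketch of the effect (a simplified stand-in, not the actual `diffusers.utils` implementation):

```py
def replace_example_docstring_sketch(example_docstring):
    # Simplified stand-in: splice the example block into the wrapped
    # function's docstring where the bare "Examples:" placeholder sits.
    def decorator(fn):
        doc = fn.__doc__ or ""
        fn.__doc__ = doc.replace("Examples:", example_docstring.strip())
        return fn
    return decorator
```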
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -90,6 +105,23 @@ def retrieve_timesteps(


 class AuraFlowPipeline(DiffusionPipeline):
108+ r"""
109+ Args:
110+ tokenizer (`T5TokenizerFast`):
111+ Tokenizer of class
112+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
113+ text_encoder ([`T5EncoderModel`]):
114+ Frozen text-encoder. AuraFlow uses
115+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
116+ [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant.
117+ vae ([`AutoencoderKL`]):
118+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
119+ transformer ([`AuraFlowTransformer2DModel`]):
120+ Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents.
121+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
122+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
123+ """
124+
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"

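Since the new class docstring enumerates the pipeline's components, here is a hedged sketch of overriding one of them at load time. The `fal/AuraFlow` repo id comes from `EXAMPLE_DOC_STRING`; the `subfolder` name is an assumption based on the usual diffusers checkpoint layout:

```py
import torch
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel

# Load the transformer separately (subfolder name is an assumption),
# then hand it to the pipeline so the remaining components load as usual.
transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow", subfolder="transformer", torch_dtype=torch.float16
)
pipe = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow", transformer=transformer, torch_dtype=torch.float16
).to("cuda")
```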
@@ -201,8 +233,12 @@ def encode_prompt(
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
             max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
         """
         if device is None:
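The two new attention-mask arguments pair with the existing embedding arguments, which lets embeddings be precomputed once and reused across calls. A sketch, assuming `encode_prompt` returns the embeddings and masks for the positive and negative prompts in that order (the return signature is not shown in this hunk):

```py
# Precompute embeddings once, then reuse them across calls
# (return order is an assumption based on the arguments documented above).
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
    prompt="A cat holding a sign that says hello world",
    do_classifier_free_guidance=True,
    device="cuda",
)
image = pipe(
    prompt=None,
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
).images[0]
```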
@@ -345,6 +381,7 @@ def upcast_vae(self):
             self.vae.decoder.mid_block.to(dtype)

     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -366,6 +403,71 @@ def __call__(
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ) -> Union[ImagePipelineOutput, Tuple]:
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image. This is set to 512 by default.
+            width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image. This is set to 512 by default.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+                `num_inference_steps` and `timesteps` must be `None`.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
+                argument.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+            max_sequence_length (`int`, *optional*, defaults to 256):
+                Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated images.
+        """
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
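To close the loop on the new `__call__` docstring, a usage example exercising the documented arguments; values are illustrative, and `pipe` is the pipeline loaded as in `EXAMPLE_DOC_STRING`:

```py
import torch

# Seeded generator makes the run reproducible, per the `generator` docs above.
generator = torch.Generator("cuda").manual_seed(0)
image = pipe(
    prompt="A cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=5.0,
    num_images_per_prompt=1,
    generator=generator,
).images[0]
image.save("aura_flow_seeded.png")
```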