Commit 27021ac

add doc for redux
1 parent 201d8dc commit 27021ac

File tree

2 files changed: +93 −23 lines changed


docs/source/en/api/pipelines/flux.md

Lines changed: 35 additions & 0 deletions
@@ -172,6 +172,41 @@ image = pipe(

### Redux

* The Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell for image-to-image generation.
* You can first use the `FluxPriorReduxPipeline` to get the `prompt_embeds` and `pooled_prompt_embeds`, and then feed them into the `FluxPipeline` for image-to-image generation.
* When using `FluxPriorReduxPipeline` with a base pipeline, you can set `text_encoder=None` and `text_encoder_2=None` in the base pipeline to save VRAM; a sketch of the alternative, which keeps the text encoders loaded, follows the example below.

```python
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image

device = "cuda"
dtype = torch.bfloat16

repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
repo_base = "black-forest-labs/FLUX.1-dev"
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
pipe = FluxPipeline.from_pretrained(
    repo_base,
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=dtype,
).to(device)

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")
pipe_prior_output = pipe_prior_redux(image)
images = pipe(
    guidance_scale=2.5,
    num_inference_steps=50,
    generator=torch.Generator("cpu").manual_seed(0),
    **pipe_prior_output,
).images
images[0].save("flux-redux.png")
```
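If you want to combine the Redux image conditioning with a text prompt, one option is to keep the text encoders and share them with the prior pipeline instead of dropping them. A minimal sketch, assuming `FluxPriorReduxPipeline` accepts the standard `text_encoder`/`tokenizer` components and a `prompt` argument (the `encode_prompt` branch in this commit's pipeline code suggests this, but the doc above does not show it):

```python
# Sketch, not from this commit: share the base pipeline's text encoders with the
# prior Redux pipeline so a text prompt can be encoded alongside the image.
pipe = FluxPipeline.from_pretrained(repo_base, torch_dtype=dtype).to(device)
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    repo_redux,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer_2=pipe.tokenizer_2,
    torch_dtype=dtype,
).to(device)
pipe_prior_output = pipe_prior_redux(image, prompt="a puppy in a pond")  # `prompt` arg assumed
```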
## Running FP16 inference

Flux can generate high-quality images with FP16 (e.g. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing the text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
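A minimal sketch of that workaround, assuming upcasting the two text encoders is all that is needed; the checkpoint and prompt are illustrative:

```python
import torch
from diffusers import FluxPipeline

# Load the whole pipeline in FP16 for speed on Turing/Volta GPUs...
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16).to("cuda")
# ...but keep both text encoders in FP32 so their activations are not clipped.
pipe.text_encoder.to(torch.float32)
pipe.text_encoder_2.to(torch.float32)

image = pipe("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("flux-fp16.png")
```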

src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py

Lines changed: 58 additions & 23 deletions
@@ -53,32 +53,59 @@
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import FluxPipeline
-
-        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
-        >>> pipe.to("cuda")
-        >>> prompt = "A cat holding a sign that says hello world"
-        >>> # Depending on the variant being used, the pipeline call will slightly vary.
-        >>> # Refer to the pipeline documentation for more details.
-        >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
-        >>> image.save("flux.png")
+        >>> from diffusers import FluxPriorReduxPipeline, FluxPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> device = "cuda"
+        >>> dtype = torch.bfloat16
+
+        >>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
+        >>> repo_base = "black-forest-labs/FLUX.1-dev"
+        >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
+        >>> pipe = FluxPipeline.from_pretrained(
+        ...     repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16
+        ... ).to(device)
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
+        ... )
+        >>> pipe_prior_output = pipe_prior_redux(image)
+        >>> images = pipe(
+        ...     guidance_scale=2.5,
+        ...     num_inference_steps=50,
+        ...     generator=torch.Generator("cpu").manual_seed(0),
+        ...     **pipe_prior_output,
+        ... ).images
+        >>> images[0].save("flux-redux.png")
         ```
 """


 class FluxPriorReduxPipeline(DiffusionPipeline):
     r"""
-    The Flux pipeline for text-to-image generation.
+    The Flux Redux pipeline for image-to-image generation.

-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+    Reference: https://blackforestlabs.ai/flux-1-tools/

     Args:
-        transformer ([`FluxTransformer2DModel`]):
-            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
-        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        image_encoder ([`SiglipVisionModel`]):
+            SigLIP vision model to encode the input image.
+        feature_extractor ([`SiglipImageProcessor`]):
+            Image processor for preprocessing images for the SigLIP model.
+        image_embedder ([`ReduxImageEncoder`]):
+            Redux image encoder to process the SigLIP embeddings.
+        text_encoder ([`CLIPTextModel`], *optional*):
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`T5EncoderModel`], *optional*):
+            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+        tokenizer (`CLIPTokenizer`, *optional*):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`T5TokenizerFast`, *optional*):
+            Second tokenizer of class
+            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """

     model_cpu_offload_seq = "image_encoder->image_embedder"
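Because the pipeline declares `model_cpu_offload_seq = "image_encoder->image_embedder"`, the generic diffusers offloading API should apply here. A small sketch (reusing the `repo_redux`, `dtype`, and `image` names from the doc example above, and assuming a CUDA device):

```python
# Sketch: instead of moving the whole pipeline to the GPU, let diffusers shuttle
# image_encoder and image_embedder on and off the device in the declared order.
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype)
pipe_prior_redux.enable_model_cpu_offload()
pipe_prior_output = pipe_prior_redux(image)
```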
@@ -121,6 +148,7 @@ def encode_image(self, image, device, num_images_per_prompt):
             images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
         )
         image = image.to(device=device, dtype=dtype)
+
         image_enc_hidden_states = self.image_encoder(**image).last_hidden_state
         image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)

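The `repeat_interleave` call above duplicates each image's hidden states once per requested output, keeping copies of the same image adjacent in the batch. An illustrative snippet (the shapes are hypothetical, not read from the model config):

```python
import torch

# Hypothetical (batch, tokens, hidden) shape for two encoded images.
hidden_states = torch.randn(2, 729, 1152)
# Three outputs per image -> (6, 729, 1152), ordered img0, img0, img0, img1, img1, img1.
repeated = hidden_states.repeat_interleave(3, dim=0)
print(repeated.shape)  # torch.Size([6, 729, 1152])
```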
@@ -312,16 +340,20 @@ def __call__(
         Function invoked when calling the pipeline for generation.

         Args:
-            prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
-                instead.
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+                numpy arrays and PyTorch tensors, the expected value range is `[0, 1]`. If it's a tensor or a list
+                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.

         Examples:

         Returns:
-            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
-            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
-            images.
+            [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`:
+                [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+                returning a tuple, the elements are the prompt embeddings and the pooled prompt embeddings.
         """

         # 2. Define call parameters
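Given that return contract, the output unpacks straight into the base `FluxPipeline` call, as the doc example does with `**pipe_prior_output`. A small sketch of inspecting it directly (field names inferred from that unpacking; this is not verbatim from the commit):

```python
out = pipe_prior_redux(image)  # FluxPriorReduxPipelineOutput when return_dict=True (the default)
print(out.prompt_embeds.shape, out.pooled_prompt_embeds.shape)

# Or request a plain tuple of (prompt_embeds, pooled_prompt_embeds):
prompt_embeds, pooled_prompt_embeds = pipe_prior_redux(image, return_dict=False)
```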
@@ -335,6 +367,7 @@ def __call__(

         # 3. Prepare image embeddings
         image_latents = self.encode_image(image, device, 1)
+
         image_embeds = self.image_embedder(image_latents).image_embeds
         image_embeds = image_embeds.to(device=device)

@@ -355,7 +388,9 @@ def __call__(
                 lora_scale=None,
             )
         else:
+            # max_sequence_length is 512; the T5 encoder hidden size is 4096
             prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
+            # 768 is the CLIP text encoder's pooled hidden size
             pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)

         # Concatenate image and text embeddings
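That final comment points at the key step: with no text encoders loaded, the zero tensors stand in for an empty prompt, and the Redux image tokens are appended to the text tokens. An illustrative sketch (the image token count of 729 and the `dim=1` concatenation axis are assumptions, not read from this diff):

```python
import torch

batch_size = 1
# T5 fallback: max_sequence_length x hidden size, all zeros (an "empty" prompt)
prompt_embeds = torch.zeros((batch_size, 512, 4096))
# Hypothetical Redux image tokens, already projected to the T5 hidden size
image_embeds = torch.randn((batch_size, 729, 4096))
# Joint sequence the transformer attends over: (1, 512 + 729, 4096)
joint_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
```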
