Skip to content

Commit a592002

Browse files
committed
add
1 parent 9ea52da commit a592002

File tree

8 files changed

+247
-1
lines changed

8 files changed

+247
-1
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@
276276
"FluxImg2ImgPipeline",
277277
"FluxInpaintPipeline",
278278
"FluxPipeline",
279+
"FluxPriorReduxPipeline",
279280
"HunyuanDiTControlNetPipeline",
280281
"HunyuanDiTPAGPipeline",
281282
"HunyuanDiTPipeline",
@@ -322,6 +323,7 @@
322323
"PixArtAlphaPipeline",
323324
"PixArtSigmaPAGPipeline",
324325
"PixArtSigmaPipeline",
326+
"ReduxImageEncoder",
325327
"SemanticStableDiffusionPipeline",
326328
"ShapEImg2ImgPipeline",
327329
"ShapEPipeline",
@@ -742,6 +744,7 @@
742744
FluxImg2ImgPipeline,
743745
FluxInpaintPipeline,
744746
FluxPipeline,
747+
FluxPriorReduxPipeline,
745748
HunyuanDiTControlNetPipeline,
746749
HunyuanDiTPAGPipeline,
747750
HunyuanDiTPipeline,
@@ -788,6 +791,7 @@
788791
PixArtAlphaPipeline,
789792
PixArtSigmaPAGPipeline,
790793
PixArtSigmaPipeline,
794+
ReduxImageEncoder,
791795
SemanticStableDiffusionPipeline,
792796
ShapEImg2ImgPipeline,
793797
ShapEPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@
134134
"FluxInpaintPipeline",
135135
"FluxPipeline",
136136
"FluxFillPipeline",
137+
"FluxPriorReduxPipeline",
138+
"ReduxImageEncoder",
137139
]
138140
_import_structure["audioldm"] = ["AudioLDMPipeline"]
139141
_import_structure["audioldm2"] = [
@@ -529,6 +531,8 @@
529531
FluxImg2ImgPipeline,
530532
FluxInpaintPipeline,
531533
FluxPipeline,
534+
FluxPriorReduxPipeline,
535+
ReduxImageEncoder,
532536
)
533537
from .hunyuandit import HunyuanDiTPipeline
534538
from .i2vgen_xl import I2VGenXLPipeline

src/diffusers/pipelines/flux/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
_dummy_objects = {}
1414
_additional_imports = {}
15-
_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
15+
_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
1616

1717
try:
1818
if not (is_transformers_available() and is_torch_available()):
@@ -22,27 +22,31 @@
2222

2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
25+
_import_structure["modeling_flux"] = ["ReduxImageEncoder"]
2526
_import_structure["pipeline_flux"] = ["FluxPipeline"]
2627
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
2728
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
2829
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
2930
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
3031
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
3132
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
33+
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
3234
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
3335
try:
3436
if not (is_transformers_available() and is_torch_available()):
3537
raise OptionalDependencyNotAvailable()
3638
except OptionalDependencyNotAvailable:
3739
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
3840
else:
41+
from .modeling_flux import ReduxImageEncoder
3942
from .pipeline_flux import FluxPipeline
4043
from .pipeline_flux_controlnet import FluxControlNetPipeline
4144
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
4245
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
4346
from .pipeline_flux_fill import FluxFillPipeline
4447
from .pipeline_flux_img2img import FluxImg2ImgPipeline
4548
from .pipeline_flux_inpaint import FluxInpaintPipeline
49+
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
4650
else:
4751
import sys
4852

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from dataclasses import dataclass
17+
from typing import Optional
18+
19+
import torch
20+
import torch.nn as nn
21+
22+
from ...configuration_utils import ConfigMixin, register_to_config
23+
from ...models.modeling_utils import ModelMixin
24+
from ...utils import BaseOutput
25+
26+
27+
@dataclass
class ReduxImageEncoderOutput(BaseOutput):
    """
    Output container returned by `ReduxImageEncoder.forward`.

    Args:
        image_embeds (`torch.Tensor`, *optional*):
            The input after it has been passed through the encoder's two linear projections.
    """

    image_embeds: Optional[torch.Tensor] = None
30+
31+
32+
class ReduxImageEncoder(ModelMixin, ConfigMixin):
    """
    Two-layer projection head: a linear up-projection, a SiLU activation, then a linear down-projection.

    Args:
        redux_dim (`int`, defaults to 1152):
            Size of the last dimension of the input accepted by the up-projection.
        txt_in_features (`int`, defaults to 4096):
            Size of the last dimension of the output. The intermediate width is three times this value.
    """

    @register_to_config
    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
    ) -> None:
        super().__init__()

        # Both layers share the same intermediate width, so name it once.
        hidden_features = txt_in_features * 3
        self.redux_up = nn.Linear(redux_dim, hidden_features)
        self.redux_down = nn.Linear(hidden_features, txt_in_features)

    def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
        """Project `x` up, apply SiLU, project back down and wrap the result in the output dataclass."""
        widened = self.redux_up(x)
        activated = nn.functional.silu(widened)
        narrowed = self.redux_down(activated)

        return ReduxImageEncoderOutput(image_embeds=narrowed)

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ def __call__(
604604
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
605605
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
606606
max_sequence_length: int = 512,
607+
image_embeds: Optional[torch.Tensor] = None,
607608
):
608609
r"""
609610
Function invoked when calling the pipeline for generation.
@@ -800,6 +801,13 @@ def __call__(
800801
else:
801802
guidance = None
802803

804+
# prepare redux
805+
if image_embeds is not None:
806+
image_embeds = image_embeds.to(device=device, dtype=prompt_embeds.dtype)
807+
img_text_ids = torch.zeros(image_embeds.shape[1], 3).to(device=device, dtype=text_ids.dtype)
808+
prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
809+
text_ids = torch.cat([text_ids, img_text_ids], dim=0)
810+
803811
# 6. Denoising loop
804812
with self.progress_bar(total=num_inference_steps) as progress_bar:
805813
for i, t in enumerate(timesteps):
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import torch
17+
from transformers import SiglipImageProcessor, SiglipVisionModel
18+
19+
from ...image_processor import PipelineImageInput
20+
from ...utils import (
21+
is_torch_xla_available,
22+
logging,
23+
replace_example_docstring,
24+
)
25+
from ..pipeline_utils import DiffusionPipeline
26+
from .modeling_flux import ReduxImageEncoder
27+
from .pipeline_output import FluxPriorReduxPipelineOutput
28+
29+
30+
# Evaluated once at import time; True only when the XLA availability probe reports success.
XLA_AVAILABLE = bool(is_torch_xla_available())


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
37+
38+
# Usage example injected into `FluxPriorReduxPipeline.__call__`'s docstring. It shows the prior pipeline producing
# `image_embeds` and a `FluxPipeline` consuming them.
# NOTE(review): the checkpoint ids below are assumed to be the published Redux/base repos - confirm before release.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import FluxPipeline, FluxPriorReduxPipeline
        >>> from diffusers.utils import load_image

        >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
        ...     "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
        ... )
        >>> pipe_prior_redux.to("cuda")
        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> # Encode the reference image once, then hand the embeddings to the base pipeline.
        >>> image = load_image("reference.png")
        >>> image_embeds = pipe_prior_redux(image).image_embeds
        >>> prompt = "A cat holding a sign that says hello world"
        >>> image = pipe(prompt, image_embeds=image_embeds, num_inference_steps=50, guidance_scale=2.5).images[0]
        >>> image.save("flux-redux.png")
        ```
"""
53+
54+
55+
class FluxPriorReduxPipeline(DiffusionPipeline):
    r"""
    The Flux Redux prior pipeline. It does not generate images itself: it encodes a reference image into
    `image_embeds` that can be passed to [`FluxPipeline`] through its `image_embeds` argument.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        image_encoder ([`~transformers.SiglipVisionModel`]):
            Vision model whose `last_hidden_state` is used as the image representation.
        feature_extractor ([`~transformers.SiglipImageProcessor`]):
            Image processor that turns the input image(s) into the tensors consumed by `image_encoder`.
        image_embedder ([`ReduxImageEncoder`]):
            Projection model applied to the `image_encoder` hidden states to produce the returned `image_embeds`.
    """

    model_cpu_offload_seq = "image_encoder->image_embedder"
    _optional_components = []
    _callback_tensor_inputs = []

    def __init__(
        self,
        image_encoder: SiglipVisionModel,
        feature_extractor: SiglipImageProcessor,
        image_embedder: ReduxImageEncoder,
    ):
        super().__init__()

        self.register_modules(
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            image_embedder=image_embedder,
        )

    def encode_image(self, image, device, num_images_per_prompt):
        r"""
        Preprocess `image` and return the `last_hidden_state` of `image_encoder` for it.

        Args:
            image (`PipelineImageInput`):
                A single image or a list/tuple of images.
            device (`torch.device`):
                Device the preprocessed inputs are moved to before encoding.
            num_images_per_prompt (`int`):
                Number of times each encoded image is repeated along the batch dimension.

        Returns:
            `torch.Tensor`: The image encoder hidden states, repeated `num_images_per_prompt` times along dim 0.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        # A list/tuple is already a batch; wrapping it again would hand the processor a nested list.
        images = image if isinstance(image, (list, tuple)) else [image]
        image = self.feature_extractor.preprocess(
            images=images, do_resize=True, return_tensors="pt", do_convert_rgb=True
        )
        image = image.to(device=device, dtype=dtype)

        image_enc_hidden_states = self.image_encoder(**image).last_hidden_state
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)

        return image_enc_hidden_states

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: PipelineImageInput,
        return_dict: bool = True,
    ):
        r"""
        Function invoked when calling the pipeline to encode a reference image.

        Args:
            image (`PipelineImageInput`):
                The reference image (or list of images) to encode into `image_embeds`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`:
            [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple` whose
            only element is the `image_embeds` tensor.
        """

        # 1. Define call parameters
        device = self._execution_device

        # 2. Encode the image, then project the hidden states with the image embedder
        image_latents = self.encode_image(image, device, 1)
        image_embeds = self.image_embedder(image_latents).image_embeds

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image_embeds,)

        return FluxPriorReduxPipelineOutput(image_embeds=image_embeds)

src/diffusers/pipelines/flux/pipeline_output.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import PIL.Image
6+
import torch
67

78
from ...utils import BaseOutput
89

@@ -19,3 +20,17 @@ class FluxPipelineOutput(BaseOutput):
1920
"""
2021

2122
images: Union[List[PIL.Image.Image], np.ndarray]
23+
24+
25+
@dataclass
class FluxPriorReduxPipelineOutput(BaseOutput):
    """
    Output class for Flux Prior Redux pipelines.

    Args:
        image_embeds (`torch.Tensor`):
            Image embeddings computed by the prior pipeline from its input image. They are meant to be passed on as
            the `image_embeds` argument of a Flux pipeline.
    """

    image_embeds: torch.Tensor

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,21 @@ def from_pretrained(cls, *args, **kwargs):
482482
requires_backends(cls, ["torch", "transformers"])
483483

484484

485+
class FluxPriorReduxPipeline(metaclass=DummyObject):
    # Stand-in exported under the real class name; every entry point defers to `requires_backends`.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)
498+
499+
485500
class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
486501
_backends = ["torch", "transformers"]
487502

@@ -1172,6 +1187,21 @@ def from_pretrained(cls, *args, **kwargs):
11721187
requires_backends(cls, ["torch", "transformers"])
11731188

11741189

1190+
class ReduxImageEncoder(metaclass=DummyObject):
    # Stand-in exported under the real class name; every entry point defers to `requires_backends`.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)
1203+
1204+
11751205
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
11761206
_backends = ["torch", "transformers"]
11771207

0 commit comments

Comments
 (0)