Skip to content

Commit 0636e9d

Browse files
committed
up
1 parent 9381dd6 commit 0636e9d

File tree

6 files changed

+126
-24
lines changed

6 files changed

+126
-24
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,8 @@
364364
else:
365365
_import_structure["modular_pipelines"].extend(
366366
[
367+
"FluxAutoBlocks",
368+
"FluxModularPipeline",
367369
"StableDiffusionXLAutoBlocks",
368370
"StableDiffusionXLModularPipeline",
369371
"WanAutoBlocks",
@@ -999,6 +1001,8 @@
9991001
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
10001002
else:
10011003
from .modular_pipelines import (
1004+
FluxAutoBlocks,
1005+
FluxModularPipeline,
10021006
StableDiffusionXLAutoBlocks,
10031007
StableDiffusionXLModularPipeline,
10041008
WanAutoBlocks,

src/diffusers/modular_pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
]
4242
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
4343
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
44+
_import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
4445
_import_structure["components_manager"] = ["ComponentsManager"]
4546

4647
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -51,6 +52,7 @@
5152
from ..utils.dummy_pt_objects import * # noqa F403
5253
else:
5354
from .components_manager import ComponentsManager
55+
from .flux import FluxAutoBlocks, FluxModularPipeline
5456
from .modular_pipeline import (
5557
AutoPipelineBlocks,
5658
BlockState,
src/diffusers/modular_pipelines/flux/__init__.py (new file — path inferred from the `from .flux import …` edit above and the module's relative imports)

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import TYPE_CHECKING
2+
3+
from ...utils import (
4+
DIFFUSERS_SLOW_IMPORT,
5+
OptionalDependencyNotAvailable,
6+
_LazyModule,
7+
get_objects_from_module,
8+
is_torch_available,
9+
is_transformers_available,
10+
)
11+
12+
13+
_dummy_objects = {}
14+
_import_structure = {}
15+
16+
try:
17+
if not (is_transformers_available() and is_torch_available()):
18+
raise OptionalDependencyNotAvailable()
19+
except OptionalDependencyNotAvailable:
20+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
21+
22+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
23+
else:
24+
_import_structure["encoders"] = ["FluxTextEncoderStep"]
25+
_import_structure["modular_blocks"] = [
26+
"ALL_BLOCKS",
27+
"AUTO_BLOCKS",
28+
"TEXT2IMAGE_BLOCKS",
29+
"FluxAutoBeforeDenoiseStep",
30+
"FluxAutoBlocks",
32+
"FluxAutoDecodeStep",
33+
"FluxAutoDenoiseStep",
34+
]
35+
_import_structure["modular_pipeline"] = ["FluxModularPipeline"]
36+
37+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
38+
try:
39+
if not (is_transformers_available() and is_torch_available()):
40+
raise OptionalDependencyNotAvailable()
41+
except OptionalDependencyNotAvailable:
42+
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
43+
else:
44+
from .encoders import FluxTextEncoderStep
45+
from .modular_blocks import (
46+
ALL_BLOCKS,
47+
AUTO_BLOCKS,
48+
TEXT2IMAGE_BLOCKS,
49+
FluxAutoBeforeDenoiseStep,
50+
FluxAutoBlocks,
51+
FluxAutoDecodeStep,
52+
FluxAutoDenoiseStep,
53+
)
54+
from .modular_pipeline import FluxModularPipeline
55+
else:
56+
import sys
57+
58+
sys.modules[__name__] = _LazyModule(
59+
__name__,
60+
globals()["__file__"],
61+
_import_structure,
62+
module_spec=__spec__,
63+
)
64+
65+
for name, value in _dummy_objects.items():
66+
setattr(sys.modules[__name__], name, value)

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -103,28 +103,28 @@ def calculate_shift(
103103
return mu
104104

105105

106-
# Copied from diffusers.pipelines.flux.pipeline_flux._pack_latents
106+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
107107
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
108-
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
109-
latents = latents.permute(0, 2, 4, 1, 3, 5)
110-
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
108+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
109+
latents = latents.permute(0, 2, 4, 1, 3, 5)
110+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
111111

112-
return latents
112+
return latents
113113

114114

115-
# Copied from diffusers.pipelines.flux.pipeline_flux._prepare_latent_image_ids
115+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
116116
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
117-
latent_image_ids = torch.zeros(height, width, 3)
118-
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
119-
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
117+
latent_image_ids = torch.zeros(height, width, 3)
118+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
119+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
120120

121-
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
121+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
122122

123-
latent_image_ids = latent_image_ids.reshape(
124-
latent_image_id_height * latent_image_id_width, latent_image_id_channels
125-
)
123+
latent_image_ids = latent_image_ids.reshape(
124+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
125+
)
126126

127-
return latent_image_ids.to(device=device, dtype=dtype)
127+
return latent_image_ids.to(device=device, dtype=dtype)
128128

129129

130130
class FluxInputStep(PipelineBlock):

src/diffusers/modular_pipelines/flux/decoders.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,21 @@
2929
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3030

3131

32-
# Copied from diffusers.pipelines.flux.pipeline_flux._unpack_latents
32+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
3333
def _unpack_latents(latents, height, width, vae_scale_factor):
34-
batch_size, num_patches, channels = latents.shape
34+
batch_size, num_patches, channels = latents.shape
3535

36-
# VAE applies 8x compression on images but we must also account for packing which requires
37-
# latent height and width to be divisible by 2.
38-
height = 2 * (int(height) // (vae_scale_factor * 2))
39-
width = 2 * (int(width) // (vae_scale_factor * 2))
36+
# VAE applies 8x compression on images but we must also account for packing which requires
37+
# latent height and width to be divisible by 2.
38+
height = 2 * (int(height) // (vae_scale_factor * 2))
39+
width = 2 * (int(width) // (vae_scale_factor * 2))
4040

41-
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
42-
latents = latents.permute(0, 3, 1, 4, 2, 5)
41+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
42+
latents = latents.permute(0, 3, 1, 4, 2, 5)
4343

44-
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
44+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
4545

46-
return latents
46+
return latents
4747

4848

4949
class FluxDecodeStep(PipelineBlock):

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,36 @@
22
from ..utils import DummyObject, requires_backends
33

44

5+
class FluxAutoBlocks(metaclass=DummyObject):
6+
_backends = ["torch", "transformers"]
7+
8+
def __init__(self, *args, **kwargs):
9+
requires_backends(self, ["torch", "transformers"])
10+
11+
@classmethod
12+
def from_config(cls, *args, **kwargs):
13+
requires_backends(cls, ["torch", "transformers"])
14+
15+
@classmethod
16+
def from_pretrained(cls, *args, **kwargs):
17+
requires_backends(cls, ["torch", "transformers"])
18+
19+
20+
class FluxModularPipeline(metaclass=DummyObject):
21+
_backends = ["torch", "transformers"]
22+
23+
def __init__(self, *args, **kwargs):
24+
requires_backends(self, ["torch", "transformers"])
25+
26+
@classmethod
27+
def from_config(cls, *args, **kwargs):
28+
requires_backends(cls, ["torch", "transformers"])
29+
30+
@classmethod
31+
def from_pretrained(cls, *args, **kwargs):
32+
requires_backends(cls, ["torch", "transformers"])
33+
34+
535
class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
636
_backends = ["torch", "transformers"]
737

0 commit comments

Comments (0)