
Commit 2829679

update
1 parent e564abe commit 2829679

File tree: 4 files changed (+94, -8 lines)


scripts/convert_flux_to_diffusers.py

Lines changed: 5 additions & 1 deletion
@@ -37,6 +37,8 @@
 parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
 parser.add_argument("--filename", default="flux.safetensors", type=str)
 parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--in_channels", type=int, default=64)
+parser.add_argument("--out_channels", type=int, default=None)
 parser.add_argument("--vae", action="store_true")
 parser.add_argument("--transformer", action="store_true")
 parser.add_argument("--output_path", type=str)
@@ -282,7 +284,9 @@ def main(args):
         converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
             original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
         )
-        transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
+        transformer = FluxTransformer2DModel(
+            in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
+        )
         transformer.load_state_dict(converted_transformer_state_dict, strict=True)

         print(
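A note on why the converter now takes explicit channel counts: `load_state_dict(..., strict=True)` fails if the checkpoint's input/output projection shapes disagree with the default 64-channel config. A hypothetical alternative would be to infer both values from the converted state dict itself (the key names below follow the converted Diffusers layout and Flux's patch_size=1, and are assumptions, not part of this commit):

    # Hypothetical sanity check: read the channel counts off the converted weights
    # instead of hard-coding them, so strict loading cannot trip over a shape mismatch.
    in_channels = converted_transformer_state_dict["x_embedder.weight"].shape[1]
    out_channels = converted_transformer_state_dict["proj_out.weight"].shape[0]
    transformer = FluxTransformer2DModel(
        in_channels=in_channels, out_channels=out_channels, guidance_embeds=has_guidance
    )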

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 5 additions & 2 deletions
@@ -238,6 +238,7 @@ def __init__(
         self,
         patch_size: int = 1,
         in_channels: int = 64,
+        out_channels: Optional[int] = None,
         num_layers: int = 19,
         num_single_layers: int = 38,
         attention_head_dim: int = 128,
@@ -248,7 +249,7 @@ def __init__(
         axes_dims_rope: Tuple[int] = (16, 56, 56),
     ):
         super().__init__()
-        self.out_channels = in_channels
+        self.out_channels = out_channels or in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -261,7 +262,7 @@ def __init__(
         )

         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [
@@ -449,13 +450,15 @@ def forward(
                 logger.warning(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
+
         hidden_states = self.x_embedder(hidden_states)

         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
         else:
             guidance = None
+
         temb = (
             self.time_text_embed(timestep, pooled_projections)
             if guidance is None
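A short sketch of the new fallback: leaving `out_channels` unset keeps the old symmetric behaviour, while setting it decouples the input width (e.g. image plus control latents) from the predicted output width. The shrunken layer counts and the `proj_out` attribute name are assumptions made only to keep the example cheap to run:

    from diffusers import FluxTransformer2DModel

    # out_channels=None falls back to in_channels, matching the previous behaviour.
    model = FluxTransformer2DModel(in_channels=64, num_layers=1, num_single_layers=1)
    assert model.out_channels == 64

    # Asymmetric configuration: 128 packed input channels, 64 predicted output channels.
    model = FluxTransformer2DModel(in_channels=128, out_channels=64, num_layers=1, num_single_layers=1)
    assert model.x_embedder.in_features == 128 and model.proj_out.out_features == 64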

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 83 additions & 5 deletions
@@ -19,7 +19,7 @@
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
@@ -513,7 +513,7 @@ def prepare_latents(
         shape = (batch_size, num_channels_latents, height, width)

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids

         if isinstance(generator, list) and len(generator) != batch_size:
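A note on the `prepare_latents` fix above: Flux packs latents into 2x2 patches before they enter the transformer, so the positional ids for user-supplied latents have to be built on the halved grid. The sketch below mirrors the packing layout used by the pipeline's `_pack_latents` helper; the concrete sizes are illustrative:

    import torch

    # A 1024x1024 image gives a (1, 16, 128, 128) VAE latent. Packing 2x2 patches
    # yields a (1, 64 * 64, 16 * 4) sequence, so latent_image_ids must be generated
    # for the halved 64 x 64 grid rather than the full 128 x 128 one.
    latents = torch.randn(1, 16, 128, 128)
    packed = latents.view(1, 16, 64, 2, 64, 2).permute(0, 2, 4, 1, 3, 5).reshape(1, 64 * 64, 16 * 4)
    assert packed.shape == (1, 4096, 64)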
@@ -529,6 +529,41 @@ def prepare_latents(

         return latents, latent_image_ids

+    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+    def prepare_image(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        if isinstance(image, torch.Tensor):
+            pass
+        else:
+            image = self.image_processor.preprocess(image, height=height, width=width)
+
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -556,9 +591,11 @@ def __call__(
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 3.5,
+        control_image: PipelineImageInput = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
+        control_latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -595,6 +632,14 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -667,6 +712,7 @@ def __call__(

         device = self._execution_device

+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
@@ -686,7 +732,35 @@ def __call__(
         )

         # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
+        num_channels_latents = (
+            self.transformer.config.in_channels // 4
+            if control_image is None
+            else self.transformer.config.in_channels // 8
+        )
+
+        if control_image is not None and control_latents is None:
+            control_image = self.prepare_image(
+                image=control_image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.vae.dtype,
+            )
+
+            control_latents = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+            control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+            height_control_image, width_control_image = control_latents.shape[2:]
+            control_latents = self._pack_latents(
+                control_latents,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height_control_image,
+                width_control_image,
+            )
+
         latents, latent_image_ids = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
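For reference, the channel bookkeeping behind the new `// 8` branch in the hunk above; the value of `in_channels` is illustrative and assumes a transformer converted with doubled input channels:

    in_channels = 128                             # hypothetical transformer.config.in_channels
    num_channels_latents = in_channels // 8       # 16 VAE channels per stream
    packed_per_stream = num_channels_latents * 4  # 64 channels after 2x2 patch packing
    assert packed_per_stream * 2 == in_channels   # image + control latents concatenated on dim=2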
@@ -732,11 +806,16 @@ def __call__(
                 if self.interrupt:
                     continue

+                if control_latents is not None:
+                    latent_model_input = torch.cat([latents, control_latents], dim=2)
+                else:
+                    latent_model_input = latents
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -774,7 +853,6 @@ def __call__(

         if output_type == "latent":
             image = latents
-
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
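Taken together, the pipeline changes let a control image be VAE-encoded, packed, and concatenated with the noise latents along dim=2 at every denoising step. A minimal usage sketch, assuming a checkpoint whose transformer was converted with the expanded input channels; the repo id and image URL below are placeholders, not real artifacts:

    import torch
    from diffusers import FluxPipeline
    from diffusers.utils import load_image

    # Placeholder repo id: any Flux checkpoint whose transformer expects the doubled packed input.
    pipe = FluxPipeline.from_pretrained("your-org/flux-control-example", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    control_image = load_image("https://example.com/condition.png")  # placeholder URL
    image = pipe(
        prompt="a photo of a cat",
        control_image=control_image,
        height=1024,
        width=1024,
        num_inference_steps=28,
        guidance_scale=3.5,
    ).images[0]
    image.save("flux_control.png")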

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 1 addition & 0 deletions
@@ -736,6 +736,7 @@ def __call__(
         device = self._execution_device
         dtype = self.transformer.dtype

+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
