
Commit e88328d

support infiniteyou
1 parent c7035ad · commit e88328d

File tree: 7 files changed, +304 −1 lines


README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -42,6 +42,8 @@ Until now, DiffSynth-Studio has supported the following models:
 * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
 
 ## News
+- **March 31, 2025** We support InfiniteYou, an identity-preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
+
 - **March 25, 2025** 🔥🔥🔥 Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
 
 - **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
```

diffsynth/configs/model_config.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -37,6 +37,7 @@
 from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
 from ..models.flux_controlnet import FluxControlNet
 from ..models.flux_ipadapter import FluxIpAdapter
+from ..models.flux_infiniteyou import InfiniteYouImageProjector
 
 from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
 from ..models.cog_dit import CogDiT
@@ -104,6 +105,8 @@
     (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
     (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
     (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+    (None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
+    (None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
     (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
     (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
     (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
@@ -598,6 +601,25 @@
             "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
         ],
     },
+    "InfiniteYou": {
+        "file_list": [
+            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
+            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
+            ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+            ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
+        ],
+        "load_path": [
+            [
+                "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
+                "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors",
+            ],
+            "models/InfiniteYou/image_proj_model.bin",
+        ],
+    },
     # ESRGAN
     "ESRGAN_x4": [
         ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
@@ -757,6 +779,7 @@
     "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
     "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
     "InstantX/FLUX.1-dev-IP-Adapter",
+    "InfiniteYou",
     "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
     "QwenPrompt",
     "OmostPrompt",
```

diffsynth/models/flux_controlnet.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -318,6 +318,8 @@ def from_diffusers(self, state_dict):
             extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
         elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
             extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
+        elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
+            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
         else:
             extra_kwargs = {}
         return state_dict_, extra_kwargs
```
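This branch maps InfuseNet's checkpoint to its block layout (4 joint blocks, 10 single blocks) so the generic `FluxControlNet` is instantiated with the right depth. The `hash_value` is a fingerprint of the checkpoint's key layout; a hypothetical sketch of the idea, not the library's exact implementation:

```python
import hashlib

def state_dict_layout_hash(state_dict):
    # Fingerprint the architecture by its parameter names, not its weights,
    # so every copy of the same layout hashes to the same value.
    keys = ",".join(sorted(state_dict.keys()))
    return hashlib.md5(keys.encode("utf-8")).hexdigest()
```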
diffsynth/models/flux_infiniteyou.py (new file)

Lines changed: 128 additions & 0 deletions

```python
import math
import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # keep the (bs, n_heads, length, dim_per_head) layout as one contiguous tensor
    x = x.reshape(bs, heads, length, -1)
    return x
```
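A quick shape check for these helpers (a standalone snippet, not part of the commit):

```python
import torch
from diffsynth.models.flux_infiniteyou import FeedForward, reshape_tensor

x = torch.randn(2, 16, 1280)           # (bs, length, width)
print(reshape_tensor(x, 20).shape)     # torch.Size([2, 20, 16, 64]): per-head layout
print(FeedForward(1280)(x).shape)      # torch.Size([2, 16, 1280]): the FFN preserves shape
```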
The file continues with the attention block:

```python
class PerceiverAttention(nn.Module):

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, D)
            latents (torch.Tensor): latent features, shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # more stable with fp16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
```
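`PerceiverAttention` is the resampler pattern: a fixed set of learned latent queries attends over the concatenation of the input features and the latents themselves, compressing variable-length input into a fixed number of tokens. An illustrative call:

```python
import torch
from diffsynth.models.flux_infiniteyou import PerceiverAttention

attn = PerceiverAttention(dim=1280, dim_head=64, heads=20)
image_feats = torch.randn(2, 1, 1280)    # projected identity features
latents = torch.randn(2, 8, 1280)        # 8 learned query tokens
print(attn(image_feats, latents).shape)  # torch.Size([2, 8, 1280])
```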
And the projector itself, together with its state-dict converter:

```python
class InfiniteYouImageProjector(nn.Module):

    def __init__(
        self,
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=8,
        embedding_dim=512,
        output_dim=4096,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        latents = self.proj_out(latents)
        return self.norm_out(latents)

    @staticmethod
    def state_dict_converter():
        return FluxInfiniteYouImageProjectorStateDictConverter()


class FluxInfiniteYouImageProjectorStateDictConverter:

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict['image_proj']
```
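With the default hyperparameters, the projector resamples a 512-d ArcFace embedding into 8 tokens of width 4096, the token width of FLUX's T5 text embeddings, which is what lets InfuseNet consume identity tokens in place of a text prompt (see the pipeline changes below). A standalone check:

```python
import torch
from diffsynth.models.flux_infiniteyou import InfiniteYouImageProjector

proj = InfiniteYouImageProjector()
face_emb = torch.randn(1, 1, 512)   # (batch, n_embeddings, embedding_dim)
print(proj(face_emb).shape)         # torch.Size([1, 8, 4096])
```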

diffsynth/pipelines/flux_image.py

Lines changed: 88 additions & 1 deletion

```diff
@@ -4,10 +4,12 @@
 from ..schedulers import FlowMatchScheduler
 from .base import BasePipeline
 from typing import List
+import math
 import torch
 from tqdm import tqdm
 import numpy as np
 from PIL import Image
+import cv2
 from ..models.tiler import FastTileWorker
 from transformers import SiglipVisionModel
 from copy import deepcopy
```
```diff
@@ -162,6 +164,20 @@ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: Lis
         self.ipadapter = model_manager.fetch_model("flux_ipadapter")
         self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
 
+        # InfiniteYou
+        self.image_proj_model = model_manager.fetch_model("infiniteyou_image_projector")
+        if self.image_proj_model is not None:
+            from facexlib.recognition import init_recognition_model
+            from insightface.app import FaceAnalysis
+            insightface_root_path = 'models/InfiniteYou/insightface'
+            self.app_640 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+            self.app_640.prepare(ctx_id=0, det_size=(640, 640))
+            self.app_320 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+            self.app_320.prepare(ctx_id=0, det_size=(320, 320))
+            self.app_160 = FaceAnalysis(name='antelopev2', root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+            self.app_160.prepare(ctx_id=0, det_size=(160, 160))
+            self.arcface_model = init_recognition_model('arcface', device=self.device)
 
     @staticmethod
     def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
```
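The three `FaceAnalysis` instances share the antelopev2 weights and differ only in `det_size`; `_detect_face` (added further below) retries at 320 and then 160 when nothing is found at 640. For reference, a standalone insightface detection call looks roughly like this (illustrative; assumes the preset's ONNX models are already under `models/InfiniteYou/insightface`):

```python
import cv2
from insightface.app import FaceAnalysis

app = FaceAnalysis(name='antelopev2', root='models/InfiniteYou/insightface',
                   providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

img_bgr = cv2.imread("id.jpg")   # insightface expects BGR numpy arrays
faces = app.get(img_bgr)         # each detected face carries 'bbox' and 'kps'
if faces:
    print(faces[0]['bbox'], faces[0]['kps'].shape)  # (4,) box, (5, 2) keypoints
```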
```diff
@@ -337,6 +353,66 @@ def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_m
         return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
 
 
+    def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
+        stickwidth = 4
+        limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+        kps = np.array(kps)
+        w, h = image_pil.size
+        out_img = np.zeros([h, w, 3])
+        for i in range(len(limbSeq)):
+            index = limbSeq[i]
+            color = color_list[index[0]]
+            x = kps[index][:, 0]
+            y = kps[index][:, 1]
+            length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+            angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+            polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+            out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+        out_img = (out_img * 0.6).astype(np.uint8)
+        for idx_kp, kp in enumerate(kps):
+            color = color_list[idx_kp]
+            out_img = cv2.circle(out_img.copy(), (int(kp[0]), int(kp[1])), 10, color, -1)
+        out_img_pil = Image.fromarray(out_img.astype(np.uint8))
+        return out_img_pil
+
+
+    def extract_arcface_bgr_embedding(self, in_image, landmark):
+        from insightface.utils import face_align
+        arc_face_image = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
+        arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0, 3, 1, 2) / 255.
+        arc_face_image = 2 * arc_face_image - 1
+        arc_face_image = arc_face_image.contiguous().to(self.device)
+        face_emb = self.arcface_model(arc_face_image)[0]  # [512], normalized
+        return face_emb
+
+
+    def _detect_face(self, id_image_cv2):
+        face_info = self.app_640.get(id_image_cv2)
+        if len(face_info) > 0:
+            return face_info
+        face_info = self.app_320.get(id_image_cv2)
+        if len(face_info) > 0:
+            return face_info
+        face_info = self.app_160.get(id_image_cv2)
+        return face_info
+
+
+    def prepare_infinite_you(self, id_image, controlnet_image, controlnet_guidance, height, width):
+        if id_image is None:
+            return {'id_emb': None}, controlnet_image
+        id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
+        face_info = self._detect_face(id_image_cv2)
+        if len(face_info) == 0:
+            raise ValueError('No face detected in the input ID image')
+        landmark = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]['kps']  # use only the largest face
+        id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
+        id_emb = self.image_proj_model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))
+        if controlnet_image is None:
+            controlnet_image = Image.fromarray(np.zeros([height, width, 3]).astype(np.uint8))
+        controlnet_guidance = torch.Tensor([controlnet_guidance]).to(device=self.device, dtype=self.torch_dtype)
+        return {'id_emb': id_emb, 'controlnet_guidance': controlnet_guidance}, controlnet_image
+
+
     def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
         # Extend prompt
         self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
```
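`prepare_infinite_you` keeps the identity path compact: detect the largest face, embed it with ArcFace, then resample the 512-d vector into identity tokens; when no ControlNet image is supplied, a black canvas of the target size is handed to InfuseNet. The shapes, as an illustrative sketch (the pipeline uses the fetched, pretrained projector rather than a fresh one):

```python
import torch
from diffsynth.models.flux_infiniteyou import InfiniteYouImageProjector

image_proj_model = InfiniteYouImageProjector()
face_emb = torch.randn(512)                      # stand-in for the normalized ArcFace embedding
x = face_emb.unsqueeze(0).reshape([1, -1, 512])  # (1, 1, 512): one identity token
print(image_proj_model(x).shape)                 # torch.Size([1, 8, 4096])
```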
```diff
@@ -374,6 +450,7 @@ def __call__(
         controlnet_image=None,
         controlnet_inpaint_mask=None,
         enable_controlnet_on_negative=False,
+        controlnet_guidance=1.0,
         # IP-Adapter
         ipadapter_images=None,
         ipadapter_scale=1.0,
@@ -382,6 +459,8 @@
         eligen_entity_masks=None,
         enable_eligen_on_negative=False,
         enable_eligen_inpaint=False,
+        # InfiniteYou
+        id_image=None,
         # TeaCache
         tea_cache_l1_thresh=None,
         # Tile
```
```diff
@@ -409,6 +488,9 @@
         # Extra input
         extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
 
+        # InfiniteYou
+        infiniteyou_kwargs, controlnet_image = self.prepare_infinite_you(id_image, controlnet_image, controlnet_guidance, height, width)
+
         # Entity control
         eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
 
@@ -430,7 +512,7 @@
         inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
             dit=self.dit, controlnet=self.controlnet,
             hidden_states=latents, timestep=timestep,
-            **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs,
+            **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
         )
         noise_pred_posi = self.control_noise_via_local_prompts(
             prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
```
```diff
@@ -529,6 +611,8 @@ def lets_dance_flux(
     entity_prompt_emb=None,
     entity_masks=None,
     ipadapter_kwargs_list={},
+    id_emb=None,
+    controlnet_guidance=None,
     tea_cache: TeaCache = None,
     **kwargs
 ):
@@ -573,6 +657,9 @@ def flux_forward_fn(hl, hr, wl, wr):
             "tile_size": tile_size,
             "tile_stride": tile_stride,
         }
+        if id_emb is not None:
+            controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3).to(device=hidden_states.device, dtype=hidden_states.dtype)
+            controlnet_extra_kwargs.update({"prompt_emb": id_emb, 'text_ids': controlnet_text_ids, 'guidance': controlnet_guidance})
         controlnet_res_stack, controlnet_single_res_stack = controlnet(
             controlnet_frames, **controlnet_extra_kwargs
         )
```
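At dispatch time, the identity tokens simply replace the text tokens on the ControlNet branch, with zeroed positional ids of matching length, while the base DiT keeps its ordinary text conditioning. Restated as a tiny sketch:

```python
import torch

id_emb = torch.randn(1, 8, 4096)   # InfiniteYouImageProjector output
controlnet_text_ids = torch.zeros(id_emb.shape[0], id_emb.shape[1], 3)
controlnet_kwargs = {
    "prompt_emb": id_emb,             # identity tokens stand in for text tokens
    "text_ids": controlnet_text_ids,  # zeroed ids of matching sequence length
}
```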

examples/InfiniteYou/README.md

Lines changed: 7 additions & 0 deletions

```diff
@@ -0,0 +1,7 @@
+# InfiniteYou: Flexible Photo Recrafting While Preserving Your Identity
+We support the identity-preserving feature of InfiniteYou. See [./infiniteyou.py](./infiniteyou.py) for an example. A visualization of the results is shown below.
+
+|Identity Image|Generated Image|
+|-|-|
+|![man_id](https://github.com/user-attachments/assets/bbc38a91-966e-49e8-a0d7-c5467582ad1f)|![man](https://github.com/user-attachments/assets/0decd5e1-5f65-437c-98fa-90991b6f23c1)|
+|![woman_id](https://github.com/user-attachments/assets/b2894695-690e-465b-929c-61e5dc57feeb)|![woman](https://github.com/user-attachments/assets/67cc7496-c4d3-4de1-a8f1-9eb4991d95e8)|
```
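For completeness, a likely minimal script in the spirit of [./infiniteyou.py](./infiniteyou.py). This is a hedged sketch: the `id_image` and `controlnet_guidance` arguments come from the pipeline changes in this commit, while the FLUX.1-dev paths and load order follow DiffSynth-Studio's other FLUX examples and may differ from the actual example file:

```python
import torch
from PIL import Image
from diffsynth import ModelManager, FluxImagePipeline, download_models

download_models(["FLUX.1-dev", "InfiniteYou"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
    [
        "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
        "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors",
    ],
    "models/InfiniteYou/image_proj_model.bin",
])
pipe = FluxImagePipeline.from_model_manager(model_manager)

image = pipe(
    prompt="portrait of a person in a garden, natural light",
    id_image=Image.open("id.jpg").convert("RGB"),  # identity reference image
    controlnet_guidance=1.0,
    height=1024, width=1024, seed=0,
)
image.save("image.jpg")
```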
