44from ..schedulers import FlowMatchScheduler
55from .base import BasePipeline
66from typing import List
7+ import math
78import torch
89from tqdm import tqdm
910import numpy as np
1011from PIL import Image
12+ import cv2
1113from ..models .tiler import FastTileWorker
1214from transformers import SiglipVisionModel
1315from copy import deepcopy
@@ -162,6 +164,20 @@ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: Lis
162164 self .ipadapter = model_manager .fetch_model ("flux_ipadapter" )
163165 self .ipadapter_image_encoder = model_manager .fetch_model ("siglip_vision_model" )
164166
167+ # InfiniteYou
168+ self .image_proj_model = model_manager .fetch_model ("infiniteyou_image_projector" )
169+ if self .image_proj_model is not None :
170+ from facexlib .recognition import init_recognition_model
171+ from insightface .app import FaceAnalysis
172+ insightface_root_path = 'models/InfiniteYou/insightface'
173+ self .app_640 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
174+ self .app_640 .prepare (ctx_id = 0 , det_size = (640 , 640 ))
175+ self .app_320 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
176+ self .app_320 .prepare (ctx_id = 0 , det_size = (320 , 320 ))
177+ self .app_160 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
178+ self .app_160 .prepare (ctx_id = 0 , det_size = (160 , 160 ))
179+ self .arcface_model = init_recognition_model ('arcface' , device = self .device )
180+
165181
166182 @staticmethod
167183 def from_model_manager (model_manager : ModelManager , controlnet_config_units : List [ControlNetConfigUnit ]= [], prompt_refiner_classes = [], prompt_extender_classes = [], device = None , torch_dtype = None ):
@@ -337,6 +353,66 @@ def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_m
337353 return eligen_kwargs_posi , eligen_kwargs_nega , fg_mask , bg_mask
338354
339355
356+ def draw_kps (image_pil , kps , color_list = [(255 ,0 ,0 ), (0 ,255 ,0 ), (0 ,0 ,255 ), (255 ,255 ,0 ), (255 ,0 ,255 )]):
357+ stickwidth = 4
358+ limbSeq = np .array ([[0 , 2 ], [1 , 2 ], [3 , 2 ], [4 , 2 ]])
359+ kps = np .array (kps )
360+ w , h = image_pil .size
361+ out_img = np .zeros ([h , w , 3 ])
362+ for i in range (len (limbSeq )):
363+ index = limbSeq [i ]
364+ color = color_list [index [0 ]]
365+ x = kps [index ][:, 0 ]
366+ y = kps [index ][:, 1 ]
367+ length = ((x [0 ] - x [1 ]) ** 2 + (y [0 ] - y [1 ]) ** 2 ) ** 0.5
368+ angle = math .degrees (math .atan2 (y [0 ] - y [1 ], x [0 ] - x [1 ]))
369+ polygon = cv2 .ellipse2Poly ((int (np .mean (x )), int (np .mean (y ))), (int (length / 2 ), stickwidth ), int (angle ), 0 , 360 , 1 )
370+ out_img = cv2 .fillConvexPoly (out_img .copy (), polygon , color )
371+ out_img = (out_img * 0.6 ).astype (np .uint8 )
372+ for idx_kp , kp in enumerate (kps ):
373+ color = color_list [idx_kp ]
374+ out_img = cv2 .circle (out_img .copy (), (int (kp [0 ]), int (kp [1 ])), 10 , color , - 1 )
375+ out_img_pil = Image .fromarray (out_img .astype (np .uint8 ))
376+ return out_img_pil
377+
378+
379+ def extract_arcface_bgr_embedding (self , in_image , landmark ):
380+ from insightface .utils import face_align
381+ arc_face_image = face_align .norm_crop (in_image , landmark = np .array (landmark ), image_size = 112 )
382+ arc_face_image = torch .from_numpy (arc_face_image ).unsqueeze (0 ).permute (0 , 3 , 1 , 2 ) / 255.
383+ arc_face_image = 2 * arc_face_image - 1
384+ arc_face_image = arc_face_image .contiguous ().to (self .device )
385+ face_emb = self .arcface_model (arc_face_image )[0 ] # [512], normalized
386+ return face_emb
387+
388+
389+ def _detect_face (self , id_image_cv2 ):
390+ face_info = self .app_640 .get (id_image_cv2 )
391+ if len (face_info ) > 0 :
392+ return face_info
393+ face_info = self .app_320 .get (id_image_cv2 )
394+ if len (face_info ) > 0 :
395+ return face_info
396+ face_info = self .app_160 .get (id_image_cv2 )
397+ return face_info
398+
399+
400+ def prepare_infinite_you (self , id_image , controlnet_image , controlnet_guidance , height , width ):
401+ if id_image is None :
402+ return {'id_emb' : None }, controlnet_image
403+ id_image_cv2 = cv2 .cvtColor (np .array (id_image ), cv2 .COLOR_RGB2BGR )
404+ face_info = self ._detect_face (id_image_cv2 )
405+ if len (face_info ) == 0 :
406+ raise ValueError ('No face detected in the input ID image' )
407+ landmark = sorted (face_info , key = lambda x :(x ['bbox' ][2 ]- x ['bbox' ][0 ])* (x ['bbox' ][3 ]- x ['bbox' ][1 ]))[- 1 ]['kps' ] # only use the maximum face
408+ id_emb = self .extract_arcface_bgr_embedding (id_image_cv2 , landmark )
409+ id_emb = self .image_proj_model (id_emb .unsqueeze (0 ).reshape ([1 , - 1 , 512 ]).to (dtype = self .torch_dtype ))
410+ if controlnet_image is None :
411+ controlnet_image = Image .fromarray (np .zeros ([height , width , 3 ]).astype (np .uint8 ))
412+ controlnet_guidance = torch .Tensor ([controlnet_guidance ]).to (device = self .device , dtype = self .torch_dtype )
413+ return {'id_emb' : id_emb , 'controlnet_guidance' : controlnet_guidance }, controlnet_image
414+
415+
340416 def prepare_prompts (self , prompt , local_prompts , masks , mask_scales , t5_sequence_length , negative_prompt , cfg_scale ):
341417 # Extend prompt
342418 self .load_models_to_device (['text_encoder_1' , 'text_encoder_2' ])
@@ -374,6 +450,7 @@ def __call__(
374450 controlnet_image = None ,
375451 controlnet_inpaint_mask = None ,
376452 enable_controlnet_on_negative = False ,
453+ controlnet_guidance = 1.0 ,
377454 # IP-Adapter
378455 ipadapter_images = None ,
379456 ipadapter_scale = 1.0 ,
@@ -382,6 +459,8 @@ def __call__(
382459 eligen_entity_masks = None ,
383460 enable_eligen_on_negative = False ,
384461 enable_eligen_inpaint = False ,
462+ # InfiniteYou
463+ id_image = None ,
385464 # TeaCache
386465 tea_cache_l1_thresh = None ,
387466 # Tile
@@ -409,6 +488,9 @@ def __call__(
409488 # Extra input
410489 extra_input = self .prepare_extra_input (latents , guidance = embedded_guidance )
411490
491+ # InfiniteYou
492+ infiniteyou_kwargs , controlnet_image = self .prepare_infinite_you (id_image , controlnet_image , controlnet_guidance , height , width )
493+
412494 # Entity control
413495 eligen_kwargs_posi , eligen_kwargs_nega , fg_mask , bg_mask = self .prepare_eligen (prompt_emb_nega , eligen_entity_prompts , eligen_entity_masks , width , height , t5_sequence_length , enable_eligen_inpaint , enable_eligen_on_negative , cfg_scale )
414496
@@ -430,7 +512,7 @@ def __call__(
430512 inference_callback = lambda prompt_emb_posi , controlnet_kwargs : lets_dance_flux (
431513 dit = self .dit , controlnet = self .controlnet ,
432514 hidden_states = latents , timestep = timestep ,
433- ** prompt_emb_posi , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs , ** ipadapter_kwargs_list_posi , ** eligen_kwargs_posi , ** tea_cache_kwargs ,
515+ ** prompt_emb_posi , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs , ** ipadapter_kwargs_list_posi , ** eligen_kwargs_posi , ** tea_cache_kwargs , ** infiniteyou_kwargs
434516 )
435517 noise_pred_posi = self .control_noise_via_local_prompts (
436518 prompt_emb_posi , prompt_emb_locals , masks , mask_scales , inference_callback ,
@@ -529,6 +611,8 @@ def lets_dance_flux(
529611 entity_prompt_emb = None ,
530612 entity_masks = None ,
531613 ipadapter_kwargs_list = {},
614+ id_emb = None ,
615+ controlnet_guidance = None ,
532616 tea_cache : TeaCache = None ,
533617 ** kwargs
534618):
@@ -573,6 +657,9 @@ def flux_forward_fn(hl, hr, wl, wr):
573657 "tile_size" : tile_size ,
574658 "tile_stride" : tile_stride ,
575659 }
660+ if id_emb is not None :
661+ controlnet_text_ids = torch .zeros (id_emb .shape [0 ], id_emb .shape [1 ], 3 ).to (device = hidden_states .device , dtype = hidden_states .dtype )
662+ controlnet_extra_kwargs .update ({"prompt_emb" : id_emb , 'text_ids' : controlnet_text_ids , 'guidance' : controlnet_guidance })
576663 controlnet_res_stack , controlnet_single_res_stack = controlnet (
577664 controlnet_frames , ** controlnet_extra_kwargs
578665 )
0 commit comments