@@ -31,6 +31,7 @@ def __init__(self, device="cuda", torch_dtype=torch.float16):
3131 self .controlnet : FluxMultiControlNetManager = None
3232 self .ipadapter : FluxIpAdapter = None
3333 self .ipadapter_image_encoder : SiglipVisionModel = None
34+ self .infinityou_processor : InfinitYou = None
3435 self .model_names = ['text_encoder_1' , 'text_encoder_2' , 'dit' , 'vae_decoder' , 'vae_encoder' , 'controlnet' , 'ipadapter' , 'ipadapter_image_encoder' ]
3536
3637
@@ -162,6 +163,11 @@ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: Lis
162163 self .ipadapter = model_manager .fetch_model ("flux_ipadapter" )
163164 self .ipadapter_image_encoder = model_manager .fetch_model ("siglip_vision_model" )
164165
166+ # InfiniteYou
167+ self .image_proj_model = model_manager .fetch_model ("infiniteyou_image_projector" )
168+ if self .image_proj_model is not None :
169+ self .infinityou_processor = InfinitYou (device = self .device )
170+
165171
166172 @staticmethod
167173 def from_model_manager (model_manager : ModelManager , controlnet_config_units : List [ControlNetConfigUnit ]= [], prompt_refiner_classes = [], prompt_extender_classes = [], device = None , torch_dtype = None ):
@@ -347,6 +353,13 @@ def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence
347353 prompt_emb_nega = self .encode_prompt (negative_prompt , positive = False , t5_sequence_length = t5_sequence_length ) if cfg_scale != 1.0 else None
348354 prompt_emb_locals = [self .encode_prompt (prompt_local , t5_sequence_length = t5_sequence_length ) for prompt_local in local_prompts ]
349355 return prompt_emb_posi , prompt_emb_nega , prompt_emb_locals
356+
357+
358+ def prepare_infinite_you (self , id_image , controlnet_image , infinityou_guidance , height , width ):
359+ if self .infinityou_processor is not None and id_image is not None :
360+ return self .infinityou_processor .prepare_infinite_you (self .image_proj_model , id_image , controlnet_image , infinityou_guidance , height , width )
361+ else :
362+ return {}, controlnet_image
350363
351364
352365 @torch .no_grad ()
@@ -382,6 +395,9 @@ def __call__(
382395 eligen_entity_masks = None ,
383396 enable_eligen_on_negative = False ,
384397 enable_eligen_inpaint = False ,
398+ # InfiniteYou
399+ infinityou_id_image = None ,
400+ infinityou_guidance = 1.0 ,
385401 # TeaCache
386402 tea_cache_l1_thresh = None ,
387403 # Tile
@@ -409,6 +425,9 @@ def __call__(
409425 # Extra input
410426 extra_input = self .prepare_extra_input (latents , guidance = embedded_guidance )
411427
428+ # InfiniteYou
429+ infiniteyou_kwargs , controlnet_image = self .prepare_infinite_you (infinityou_id_image , controlnet_image , infinityou_guidance , height , width )
430+
412431 # Entity control
413432 eligen_kwargs_posi , eligen_kwargs_nega , fg_mask , bg_mask = self .prepare_eligen (prompt_emb_nega , eligen_entity_prompts , eligen_entity_masks , width , height , t5_sequence_length , enable_eligen_inpaint , enable_eligen_on_negative , cfg_scale )
414433
@@ -430,7 +449,7 @@ def __call__(
430449 inference_callback = lambda prompt_emb_posi , controlnet_kwargs : lets_dance_flux (
431450 dit = self .dit , controlnet = self .controlnet ,
432451 hidden_states = latents , timestep = timestep ,
433- ** prompt_emb_posi , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs , ** ipadapter_kwargs_list_posi , ** eligen_kwargs_posi , ** tea_cache_kwargs ,
452+ ** prompt_emb_posi , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs , ** ipadapter_kwargs_list_posi , ** eligen_kwargs_posi , ** tea_cache_kwargs , ** infiniteyou_kwargs
434453 )
435454 noise_pred_posi = self .control_noise_via_local_prompts (
436455 prompt_emb_posi , prompt_emb_locals , masks , mask_scales , inference_callback ,
@@ -447,7 +466,7 @@ def __call__(
447466 noise_pred_nega = lets_dance_flux (
448467 dit = self .dit , controlnet = self .controlnet ,
449468 hidden_states = latents , timestep = timestep ,
450- ** prompt_emb_nega , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs_nega , ** ipadapter_kwargs_list_nega , ** eligen_kwargs_nega ,
469+ ** prompt_emb_nega , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs_nega , ** ipadapter_kwargs_list_nega , ** eligen_kwargs_nega , ** infiniteyou_kwargs ,
451470 )
452471 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega )
453472 else :
@@ -467,6 +486,58 @@ def __call__(
467486 # Offload all models
468487 self .load_models_to_device ([])
469488 return image
489+
490+
491+
492+ class InfinitYou :
493+ def __init__ (self , device = "cuda" , torch_dtype = torch .bfloat16 ):
494+ from facexlib .recognition import init_recognition_model
495+ from insightface .app import FaceAnalysis
496+ self .device = device
497+ self .torch_dtype = torch_dtype
498+ insightface_root_path = 'models/InfiniteYou/insightface'
499+ self .app_640 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
500+ self .app_640 .prepare (ctx_id = 0 , det_size = (640 , 640 ))
501+ self .app_320 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
502+ self .app_320 .prepare (ctx_id = 0 , det_size = (320 , 320 ))
503+ self .app_160 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
504+ self .app_160 .prepare (ctx_id = 0 , det_size = (160 , 160 ))
505+ self .arcface_model = init_recognition_model ('arcface' , device = self .device )
506+
507+ def _detect_face (self , id_image_cv2 ):
508+ face_info = self .app_640 .get (id_image_cv2 )
509+ if len (face_info ) > 0 :
510+ return face_info
511+ face_info = self .app_320 .get (id_image_cv2 )
512+ if len (face_info ) > 0 :
513+ return face_info
514+ face_info = self .app_160 .get (id_image_cv2 )
515+ return face_info
516+
517+ def extract_arcface_bgr_embedding (self , in_image , landmark ):
518+ from insightface .utils import face_align
519+ arc_face_image = face_align .norm_crop (in_image , landmark = np .array (landmark ), image_size = 112 )
520+ arc_face_image = torch .from_numpy (arc_face_image ).unsqueeze (0 ).permute (0 , 3 , 1 , 2 ) / 255.
521+ arc_face_image = 2 * arc_face_image - 1
522+ arc_face_image = arc_face_image .contiguous ().to (self .device )
523+ face_emb = self .arcface_model (arc_face_image )[0 ] # [512], normalized
524+ return face_emb
525+
526+ def prepare_infinite_you (self , model , id_image , controlnet_image , infinityou_guidance , height , width ):
527+ import cv2
528+ if id_image is None :
529+ return {'id_emb' : None }, controlnet_image
530+ id_image_cv2 = cv2 .cvtColor (np .array (id_image ), cv2 .COLOR_RGB2BGR )
531+ face_info = self ._detect_face (id_image_cv2 )
532+ if len (face_info ) == 0 :
533+ raise ValueError ('No face detected in the input ID image' )
534+ landmark = sorted (face_info , key = lambda x :(x ['bbox' ][2 ]- x ['bbox' ][0 ])* (x ['bbox' ][3 ]- x ['bbox' ][1 ]))[- 1 ]['kps' ] # only use the maximum face
535+ id_emb = self .extract_arcface_bgr_embedding (id_image_cv2 , landmark )
536+ id_emb = model (id_emb .unsqueeze (0 ).reshape ([1 , - 1 , 512 ]).to (dtype = self .torch_dtype ))
537+ if controlnet_image is None :
538+ controlnet_image = Image .fromarray (np .zeros ([height , width , 3 ]).astype (np .uint8 ))
539+ infinityou_guidance = torch .Tensor ([infinityou_guidance ]).to (device = self .device , dtype = self .torch_dtype )
540+ return {'id_emb' : id_emb , 'infinityou_guidance' : infinityou_guidance }, controlnet_image
470541
471542
472543class TeaCache :
@@ -529,6 +600,8 @@ def lets_dance_flux(
529600 entity_prompt_emb = None ,
530601 entity_masks = None ,
531602 ipadapter_kwargs_list = {},
603+ id_emb = None ,
604+ infinityou_guidance = None ,
532605 tea_cache : TeaCache = None ,
533606 ** kwargs
534607):
@@ -573,6 +646,9 @@ def flux_forward_fn(hl, hr, wl, wr):
573646 "tile_size" : tile_size ,
574647 "tile_stride" : tile_stride ,
575648 }
649+ if id_emb is not None :
650+ controlnet_text_ids = torch .zeros (id_emb .shape [0 ], id_emb .shape [1 ], 3 ).to (device = hidden_states .device , dtype = hidden_states .dtype )
651+ controlnet_extra_kwargs .update ({"prompt_emb" : id_emb , 'text_ids' : controlnet_text_ids , 'guidance' : infinityou_guidance })
576652 controlnet_res_stack , controlnet_single_res_stack = controlnet (
577653 controlnet_frames , ** controlnet_extra_kwargs
578654 )
0 commit comments