44from ..schedulers import FlowMatchScheduler
55from .base import BasePipeline
66from typing import List
7- import math
87import torch
98from tqdm import tqdm
109import numpy as np
1110from PIL import Image
12- import cv2
1311from ..models .tiler import FastTileWorker
1412from transformers import SiglipVisionModel
1513from copy import deepcopy
@@ -33,6 +31,7 @@ def __init__(self, device="cuda", torch_dtype=torch.float16):
3331 self .controlnet : FluxMultiControlNetManager = None
3432 self .ipadapter : FluxIpAdapter = None
3533 self .ipadapter_image_encoder : SiglipVisionModel = None
34+ self .infinityou_processor : InfinitYou = None
3635 self .model_names = ['text_encoder_1' , 'text_encoder_2' , 'dit' , 'vae_decoder' , 'vae_encoder' , 'controlnet' , 'ipadapter' , 'ipadapter_image_encoder' ]
3736
3837
@@ -167,16 +166,7 @@ def fetch_models(self, model_manager: ModelManager, controlnet_config_units: Lis
167166 # InfiniteYou
168167 self .image_proj_model = model_manager .fetch_model ("infiniteyou_image_projector" )
169168 if self .image_proj_model is not None :
170- from facexlib .recognition import init_recognition_model
171- from insightface .app import FaceAnalysis
172- insightface_root_path = 'models/InfiniteYou/insightface'
173- self .app_640 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
174- self .app_640 .prepare (ctx_id = 0 , det_size = (640 , 640 ))
175- self .app_320 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
176- self .app_320 .prepare (ctx_id = 0 , det_size = (320 , 320 ))
177- self .app_160 = FaceAnalysis (name = 'antelopev2' , root = insightface_root_path , providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ])
178- self .app_160 .prepare (ctx_id = 0 , det_size = (160 , 160 ))
179- self .arcface_model = init_recognition_model ('arcface' , device = self .device )
169+ self .infinityou_processor = InfinitYou (device = self .device )
180170
181171
182172 @staticmethod
@@ -353,66 +343,6 @@ def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_m
353343 return eligen_kwargs_posi , eligen_kwargs_nega , fg_mask , bg_mask
354344
355345
def draw_kps(
    image_pil,
    kps,
    color_list=((255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)),
):
    """Draw facial keypoints as coloured limbs and dots on a black canvas.

    NOTE(review): this takes no ``self`` although it sits among instance
    methods - presumably intended as a static helper; confirm before calling
    it through an instance.

    Args:
        image_pil: Image whose ``size`` (width, height) defines the canvas
            size. Its pixel data is not read.
        kps: Sequence of ``(x, y)`` keypoints. Points 0, 1, 3 and 4 are each
            joined to point 2, so indices 0-4 must exist.
        color_list: One colour triple per keypoint. A limb takes the colour of
            its first endpoint. The default is an immutable tuple so it cannot
            be mutated across calls.

    Returns:
        A PIL image of the same size as ``image_pil`` containing the drawing.
    """
    # Imported lazily: only this helper needs them.
    import math
    import cv2

    stickwidth = 4
    limb_seq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)
    w, h = image_pil.size
    out_img = np.zeros([h, w, 3])

    for limb in limb_seq:
        color = color_list[limb[0]]
        x = kps[limb][:, 0]
        y = kps[limb][:, 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        # Each limb is a thin filled ellipse centred between its two endpoints.
        polygon = cv2.ellipse2Poly(
            (int(np.mean(x)), int(np.mean(y))),
            (int(length / 2), stickwidth),
            int(angle), 0, 360, 1,
        )
        # The canvas is local to this call, so drawing in place is safe and
        # avoids copying the whole array for every limb.
        cv2.fillConvexPoly(out_img, polygon, color)

    # Dim the limbs so the keypoint dots drawn next stand out.
    out_img = (out_img * 0.6).astype(np.uint8)

    for idx_kp, kp in enumerate(kps):
        cv2.circle(out_img, (int(kp[0]), int(kp[1])), 10, color_list[idx_kp], -1)

    return Image.fromarray(out_img.astype(np.uint8))
377-
378-
def extract_arcface_bgr_embedding(self, in_image, landmark):
    """Compute the ArcFace embedding of one face in ``in_image``.

    The face is cropped to 112x112 with ``face_align.norm_crop`` using
    ``landmark``, rescaled from [0, 255] to [-1, 1], moved to ``self.device``
    and passed through ``self.arcface_model``.

    Args:
        in_image: Image array handed to ``face_align.norm_crop`` (presumably
            BGR, as the method name suggests - confirm against the caller).
        landmark: Facial keypoints of the face to crop; converted with
            ``np.array`` before use.

    Returns:
        The first item of the recognition model's output.
    """
    from insightface.utils import face_align

    aligned = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
    # Add a batch axis and move the last axis to position 1, then scale to [0, 1].
    batch = torch.from_numpy(aligned).unsqueeze(0).permute(0, 3, 1, 2) / 255.
    # Shift the [0, 1] range to [-1, 1].
    batch = 2 * batch - 1
    batch = batch.contiguous().to(self.device)
    return self.arcface_model(batch)[0]  # [512], normalized
387-
388-
def _detect_face(self, id_image_cv2):
    """Detect faces, retrying with the smaller-input detectors on failure.

    The detectors ``self.app_640``, ``self.app_320`` and ``self.app_160`` are
    tried in that order on the same image. The first non-empty result is
    returned and the remaining detectors are not run.

    Args:
        id_image_cv2: Image passed unchanged to each detector's ``get``.

    Returns:
        The detections of the first detector that found a face, or the
        (empty) result of ``self.app_160`` if none did.
    """
    face_info = []
    # Attributes are looked up lazily by name so a later detector is only
    # touched when every earlier one came back empty.
    for detector_name in ("app_640", "app_320", "app_160"):
        face_info = getattr(self, detector_name).get(id_image_cv2)
        if len(face_info) > 0:
            break
    return face_info
398-
399-
def prepare_infinite_you(self, id_image, controlnet_image, controlnet_guidance, height, width):
    """Turn an identity image into InfiniteYou conditioning inputs.

    Args:
        id_image: Reference image of the person, or ``None`` to disable the
            feature (presumably a PIL RGB image, since it is converted with
            ``COLOR_RGB2BGR`` - confirm).
        controlnet_image: Existing ControlNet image, or ``None``.
        controlnet_guidance: Scalar guidance value, wrapped into a one-element
            tensor on ``self.device`` with dtype ``self.torch_dtype``.
        height: Height of the blank ControlNet image created when
            ``controlnet_image`` is ``None``.
        width: Width of that blank image.

    Returns:
        ``(kwargs, controlnet_image)``. Without an ID image ``kwargs`` is
        ``{'id_emb': None}`` and ``controlnet_image`` is returned untouched.
        Otherwise ``kwargs`` holds ``'id_emb'`` and ``'controlnet_guidance'``.

    Raises:
        ValueError: If no face is detected in ``id_image``.
    """
    if id_image is None:
        return {'id_emb': None}, controlnet_image

    # Deferred so that the no-ID path does not require OpenCV.
    import cv2

    id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
    face_info = self._detect_face(id_image_cv2)
    if len(face_info) == 0:
        raise ValueError('No face detected in the input ID image')

    def bbox_area(face):
        box = face['bbox']
        return (box[2] - box[0]) * (box[3] - box[1])

    # Only use the maximum face. Scanning in reverse keeps the tie-break of
    # the former ascending stable sort: the last of equally sized faces wins.
    landmark = max(reversed(list(face_info)), key=bbox_area)['kps']

    id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
    id_emb = self.image_proj_model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))

    if controlnet_image is None:
        # All-black placeholder of the requested size.
        controlnet_image = Image.fromarray(np.zeros([height, width, 3], dtype=np.uint8))

    controlnet_guidance = torch.Tensor([controlnet_guidance]).to(device=self.device, dtype=self.torch_dtype)
    return {'id_emb': id_emb, 'controlnet_guidance': controlnet_guidance}, controlnet_image
414-
415-
416346 def prepare_prompts (self , prompt , local_prompts , masks , mask_scales , t5_sequence_length , negative_prompt , cfg_scale ):
417347 # Extend prompt
418348 self .load_models_to_device (['text_encoder_1' , 'text_encoder_2' ])
@@ -423,6 +353,13 @@ def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence
423353 prompt_emb_nega = self .encode_prompt (negative_prompt , positive = False , t5_sequence_length = t5_sequence_length ) if cfg_scale != 1.0 else None
424354 prompt_emb_locals = [self .encode_prompt (prompt_local , t5_sequence_length = t5_sequence_length ) for prompt_local in local_prompts ]
425355 return prompt_emb_posi , prompt_emb_nega , prompt_emb_locals
356+
357+
def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance, height, width):
    """Build the InfiniteYou conditioning kwargs when the feature is usable.

    The work is delegated to ``self.infinityou_processor`` together with
    ``self.image_proj_model``.

    Args:
        id_image: Identity reference image, or ``None`` to skip InfiniteYou.
        controlnet_image: Current ControlNet image; returned unchanged when
            InfiniteYou is skipped.
        infinityou_guidance: Guidance value forwarded to the processor.
        height: Forwarded to the processor.
        width: Forwarded to the processor.

    Returns:
        ``(kwargs, controlnet_image)``. ``kwargs`` is ``{}`` when there is no
        ID image or no processor; otherwise both values are whatever the
        processor returns.
    """
    if id_image is None:
        return {}, controlnet_image

    if self.infinityou_processor is None:
        # Previously the ID image was dropped without any signal. The return
        # value is unchanged, but the caller is now told it had no effect.
        import warnings
        warnings.warn(
            "An InfiniteYou ID image was provided but no InfiniteYou processor is loaded; "
            "the ID image is ignored.",
            stacklevel=2,
        )
        return {}, controlnet_image

    return self.infinityou_processor.prepare_infinite_you(
        self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width
    )
426363
427364
428365 @torch .no_grad ()
@@ -450,7 +387,6 @@ def __call__(
450387 controlnet_image = None ,
451388 controlnet_inpaint_mask = None ,
452389 enable_controlnet_on_negative = False ,
453- controlnet_guidance = 1.0 ,
454390 # IP-Adapter
455391 ipadapter_images = None ,
456392 ipadapter_scale = 1.0 ,
@@ -460,7 +396,8 @@ def __call__(
460396 enable_eligen_on_negative = False ,
461397 enable_eligen_inpaint = False ,
462398 # InfiniteYou
463- id_image = None ,
399+ infinityou_id_image = None ,
400+ infinityou_guidance = 1.0 ,
464401 # TeaCache
465402 tea_cache_l1_thresh = None ,
466403 # Tile
@@ -489,7 +426,7 @@ def __call__(
489426 extra_input = self .prepare_extra_input (latents , guidance = embedded_guidance )
490427
491428 # InfiniteYou
492- infiniteyou_kwargs , controlnet_image = self .prepare_infinite_you (id_image , controlnet_image , controlnet_guidance , height , width )
429+ infiniteyou_kwargs , controlnet_image = self .prepare_infinite_you (infinityou_id_image , controlnet_image , infinityou_guidance , height , width )
493430
494431 # Entity control
495432 eligen_kwargs_posi , eligen_kwargs_nega , fg_mask , bg_mask = self .prepare_eligen (prompt_emb_nega , eligen_entity_prompts , eligen_entity_masks , width , height , t5_sequence_length , enable_eligen_inpaint , enable_eligen_on_negative , cfg_scale )
@@ -529,7 +466,7 @@ def __call__(
529466 noise_pred_nega = lets_dance_flux (
530467 dit = self .dit , controlnet = self .controlnet ,
531468 hidden_states = latents , timestep = timestep ,
532- ** prompt_emb_nega , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs_nega , ** ipadapter_kwargs_list_nega , ** eligen_kwargs_nega ,
469+ ** prompt_emb_nega , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs_nega , ** ipadapter_kwargs_list_nega , ** eligen_kwargs_nega , ** infiniteyou_kwargs ,
533470 )
534471 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega )
535472 else :
@@ -549,6 +486,58 @@ def __call__(
549486 # Offload all models
550487 self .load_models_to_device ([])
551488 return image
489+
490+
491+
class InfinitYou:
    """Face-identity preprocessing for InfiniteYou conditioning.

    Holds three face detectors prepared at 640, 320 and 160 pixel detection
    sizes plus an ArcFace recognition model, and turns an identity image into
    the keyword arguments consumed by the denoising step.

    NOTE(review): the pipeline constructs this class with only ``device``, so
    ``torch_dtype`` stays at the bfloat16 default even when the pipeline itself
    runs in another dtype - confirm this is intended.
    """

    # Detector attributes in the order they are tried by ``_detect_face``.
    _DETECTOR_ATTRS = ("app_640", "app_320", "app_160")

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16,
                 insightface_root_path='models/InfiniteYou/insightface'):
        """Load the face detectors and the ArcFace model.

        Args:
            device: Device for the ArcFace model and the produced tensors.
            torch_dtype: Dtype the identity embedding and guidance are cast to.
            insightface_root_path: Root directory handed to ``FaceAnalysis``.
        """
        # Heavy optional dependencies are imported only when the feature is used.
        from facexlib.recognition import init_recognition_model
        from insightface.app import FaceAnalysis

        self.device = device
        self.torch_dtype = torch_dtype
        self.app_640 = self._build_detector(FaceAnalysis, insightface_root_path, 640)
        self.app_320 = self._build_detector(FaceAnalysis, insightface_root_path, 320)
        self.app_160 = self._build_detector(FaceAnalysis, insightface_root_path, 160)
        self.arcface_model = init_recognition_model('arcface', device=self.device)

    @staticmethod
    def _build_detector(face_analysis_cls, root, det_size):
        """Create one 'antelopev2' analyser prepared for a square ``det_size`` input."""
        app = face_analysis_cls(name='antelopev2', root=root, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        app.prepare(ctx_id=0, det_size=(det_size, det_size))
        return app

    @staticmethod
    def _largest_face(face_info):
        """Return the detection whose ``'bbox'`` covers the largest area.

        Among equally sized boxes the last one wins, matching the ascending
        stable sort this replaces.
        """
        def bbox_area(face):
            box = face['bbox']
            return (box[2] - box[0]) * (box[3] - box[1])

        return max(reversed(list(face_info)), key=bbox_area)

    def _detect_face(self, id_image_cv2):
        """Run the detectors from largest to smallest input size.

        Returns:
            The first non-empty detection result, or the (empty) result of the
            last detector if none of them found a face.
        """
        face_info = []
        # Looked up lazily by name so later detectors are untouched on success.
        for detector_name in self._DETECTOR_ATTRS:
            face_info = getattr(self, detector_name).get(id_image_cv2)
            if len(face_info) > 0:
                break
        return face_info

    def extract_arcface_bgr_embedding(self, in_image, landmark):
        """Crop the face at ``landmark`` to 112x112 and embed it with ArcFace.

        The crop is rescaled from [0, 255] to [-1, 1] and moved to
        ``self.device`` before it is passed to ``self.arcface_model``.

        Returns:
            The first item of the recognition model's output.
        """
        from insightface.utils import face_align

        aligned = face_align.norm_crop(in_image, landmark=np.array(landmark), image_size=112)
        # Add a batch axis and move the last axis to position 1, then scale to [0, 1].
        batch = torch.from_numpy(aligned).unsqueeze(0).permute(0, 3, 1, 2) / 255.
        batch = 2 * batch - 1
        batch = batch.contiguous().to(self.device)
        return self.arcface_model(batch)[0]  # [512], normalized

    def prepare_infinite_you(self, model, id_image, controlnet_image, infinityou_guidance, height, width):
        """Produce the identity embedding and guidance tensor for one ID image.

        Args:
            model: Callable image projector applied to the ArcFace embedding
                after it is reshaped to ``[1, -1, 512]``.
            id_image: Identity reference image, or ``None`` (presumably a PIL
                RGB image, since it is converted with ``COLOR_RGB2BGR`` -
                confirm).
            controlnet_image: Existing ControlNet image, or ``None`` to get an
                all-black ``height`` x ``width`` placeholder.
            infinityou_guidance: Scalar wrapped into a one-element tensor on
                ``self.device`` with dtype ``self.torch_dtype``.
            height: Height of the placeholder ControlNet image.
            width: Width of the placeholder ControlNet image.

        Returns:
            ``(kwargs, controlnet_image)``. With ``id_image`` of ``None``,
            ``kwargs`` is ``{'id_emb': None}``. Otherwise it holds ``'id_emb'``
            and ``'infinityou_guidance'``.

        Raises:
            ValueError: If no face is detected in ``id_image``.
        """
        if id_image is None:
            return {'id_emb': None}, controlnet_image

        # Deferred until an image is actually processed, so the early return
        # above does not require OpenCV.
        import cv2

        id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
        face_info = self._detect_face(id_image_cv2)
        if len(face_info) == 0:
            raise ValueError('No face detected in the input ID image')

        landmark = self._largest_face(face_info)['kps']  # only use the maximum face
        id_emb = self.extract_arcface_bgr_embedding(id_image_cv2, landmark)
        id_emb = model(id_emb.unsqueeze(0).reshape([1, -1, 512]).to(dtype=self.torch_dtype))

        if controlnet_image is None:
            controlnet_image = Image.fromarray(np.zeros([height, width, 3], dtype=np.uint8))

        infinityou_guidance = torch.Tensor([infinityou_guidance]).to(device=self.device, dtype=self.torch_dtype)
        return {'id_emb': id_emb, 'infinityou_guidance': infinityou_guidance}, controlnet_image
552541
553542
554543class TeaCache :
@@ -612,7 +601,7 @@ def lets_dance_flux(
612601 entity_masks = None ,
613602 ipadapter_kwargs_list = {},
614603 id_emb = None ,
615- controlnet_guidance = None ,
604+ infinityou_guidance = None ,
616605 tea_cache : TeaCache = None ,
617606 ** kwargs
618607):
@@ -659,7 +648,7 @@ def flux_forward_fn(hl, hr, wl, wr):
659648 }
660649 if id_emb is not None :
661650 controlnet_text_ids = torch .zeros (id_emb .shape [0 ], id_emb .shape [1 ], 3 ).to (device = hidden_states .device , dtype = hidden_states .dtype )
662- controlnet_extra_kwargs .update ({"prompt_emb" : id_emb , 'text_ids' : controlnet_text_ids , 'guidance' : controlnet_guidance })
651+ controlnet_extra_kwargs .update ({"prompt_emb" : id_emb , 'text_ids' : controlnet_text_ids , 'guidance' : infinityou_guidance })
663652 controlnet_res_stack , controlnet_single_res_stack = controlnet (
664653 controlnet_frames , ** controlnet_extra_kwargs
665654 )
0 commit comments