@@ -354,6 +354,262 @@ def unload_ip_adapter(self):
             )
         self.unet.set_attn_processor(attn_procs)
 
+class ModularIPAdapterMixin:
+    """Mixin for handling IP Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+        subfolder: Union[str, List[str]],
+        weight_name: Union[str, List[str]],
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
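+
+        Example (an illustrative sketch; the repository id, subfolder, and weight file names below are placeholders):
+
+        ```py
+        # load a single IP-Adapter
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors"
+        )
+
+        # load several IP-Adapters at once by passing lists of equal length
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter",
+            subfolder=["models", "models"],
+            weight_name=["ip-adapter_sd15.safetensors", "ip-adapter-plus_sd15.safetensors"],
+        )
+        ```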
+        """
+
+        # handle the list inputs for multiple IP Adapters
+        if not isinstance(weight_name, list):
+            weight_name = [weight_name]
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, list):
+            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+        if len(pretrained_model_name_or_path_or_dict) == 1:
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+        if not isinstance(subfolder, list):
+            subfolder = [subfolder]
+        if len(subfolder) == 1:
+            subfolder = subfolder * len(weight_name)
+
+        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+        if len(weight_name) != len(subfolder):
+            raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+        # Load the main state dict first.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+        state_dicts = []
+        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+            pretrained_model_name_or_path_or_dict, weight_name, subfolder
+        ):
+            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                if weight_name.endswith(".safetensors"):
+                    state_dict = {"image_proj": {}, "ip_adapter": {}}
+                    with safe_open(model_file, framework="pt", device="cpu") as f:
+                        for key in f.keys():
+                            if key.startswith("image_proj."):
+                                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                            elif key.startswith("ip_adapter."):
+                                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+                else:
+                    state_dict = load_state_dict(model_file)
+            else:
+                state_dict = pretrained_model_name_or_path_or_dict
+
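+            # whether loaded from a checkpoint above or passed in directly as a dict, the state
+            # dict is expected to follow the IP-Adapter layout with `image_proj` and `ip_adapter`
+            # top-level keys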
+            keys = list(state_dict.keys())
+            if "image_proj" not in keys and "ip_adapter" not in keys:
+                raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+            state_dicts.append(state_dict)
+
+            # create feature extractor if it has not been registered to the pipeline yet
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+                # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
+                default_clip_size = 224
+                clip_image_size = (
+                    self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
+                )
+                feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
+
+        unet_name = getattr(self, "unet_name", "unet")
+        unet = getattr(self, unet_name)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
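+        # FaceID variants of IP-Adapter ship LoRA weights alongside the adapter weights; any
+        # such extra LoRAs are applied through the PEFT backend below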
+        extra_loras = unet._load_ip_adapter_loras(state_dicts)
+        if extra_loras != {}:
+            if not USE_PEFT_BACKEND:
+                logger.warning("PEFT backend is required to load these weights.")
+            else:
+                # apply the IP Adapter Face ID LoRA weights
+                peft_config = getattr(unet, "peft_config", {})
+                for k, lora in extra_loras.items():
+                    if f"faceid_{k}" not in peft_config:
+                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
+    def set_ip_adapter_scale(self, scale):
+        """
+        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
+
+        Example:
+
+        ```py
+        # To use original IP-Adapter
+        scale = 1.0
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style block only
+        scale = {
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style+layout blocks
+        scale = {
+            "down": {"block_2": [0.0, 1.0]},
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style and layout from 2 reference images
+        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+        pipeline.set_ip_adapter_scale(scales)
+        ```
+        """
+        unet_name = getattr(self, "unet_name", "unet")
+        unet = getattr(self, unet_name)
+        if not isinstance(scale, list):
+            scale = [scale]
+        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
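+        # one scale config is applied per IP-Adapter: a float sets a uniform scale, while dict
+        # configs are expanded by `_maybe_expand_lora_scales` and matched against each attention
+        # processor's name by prefix in the loop below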
+        for attn_name, attn_processor in unet.attn_processors.items():
+            if isinstance(
+                attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+            ):
+                if len(scale_configs) != len(attn_processor.scale):
+                    raise ValueError(
+                        f"Cannot assign {len(scale_configs)} scale_configs to "
+                        f"{len(attn_processor.scale)} IP-Adapter."
+                    )
+                elif len(scale_configs) == 1:
+                    scale_configs = scale_configs * len(attn_processor.scale)
+                for i, scale_config in enumerate(scale_configs):
+                    if isinstance(scale_config, dict):
+                        for k, s in scale_config.items():
+                            if attn_name.startswith(k):
+                                attn_processor.scale[i] = s
+                    else:
+                        attn_processor.scale[i] = scale_config
+
+    def unload_ip_adapter(self):
+        """
+        Unloads the IP Adapter weights.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.unload_ip_adapter()
+        >>> ...
+        ```
+        """
+
+        # remove hidden encoder
+        self.unet.encoder_hid_proj = None
+        self.unet.config.encoder_hid_dim_type = None
+
+        # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
+        if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
+            self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
+            self.unet.text_encoder_hid_proj = None
+            self.unet.config.encoder_hid_dim_type = "text_proj"
+
+        # restore original UNet attention processor layers
+        attn_procs = {}
+        for name, value in self.unet.attn_processors.items():
+            attn_processor_class = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+            )
+            attn_procs[name] = (
+                attn_processor_class
+                if isinstance(
+                    value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+                )
+                else value.__class__()
+            )
+        self.unet.set_attn_processor(attn_procs)
+
 
 class FluxIPAdapterMixin:
     """Mixin for handling Flux IP Adapters."""