@@ -33,16 +33,23 @@

 
 if is_transformers_available():
-    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-
-    from ..models.attention_processor import (
-        AttnProcessor,
-        AttnProcessor2_0,
-        IPAdapterAttnProcessor,
-        IPAdapterAttnProcessor2_0,
-        IPAdapterXFormersAttnProcessor,
+    from transformers import (
+        CLIPImageProcessor,
+        CLIPVisionModelWithProjection,
+        SiglipImageProcessor,
+        SiglipVisionModel,
     )

+from ..models.attention_processor import (
+    AttnProcessor,
+    AttnProcessor2_0,
+    JointAttnProcessor2_0,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
+    IPAdapterXFormersAttnProcessor,
+    IPAdapterJointAttnProcessor2_0,
+)
+
 logger = logging.get_logger(__name__)

@@ -348,3 +355,212 @@ def unload_ip_adapter(self):
                 else value.__class__()
             )
         self.unet.set_attn_processor(attn_procs)
+
+
+class SD3IPAdapterMixin:
+    """Mixin for handling Stable Diffusion 3 IP-Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        subfolder: str,
+        weight_name: str,
+        image_encoder_folder: Optional[str] = "image_encoder",
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            weight_name (`str`):
+                The name of the weight file to load.
+            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+                The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+                `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+                `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+                `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+                `image_encoder_folder="different_subfolder/image_encoder"`.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
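+
+        Example (a minimal sketch; the repository id, subfolder, and weight name below are placeholders, not a
+        known checkpoint layout):
+
+        ```python
+        >>> # Hypothetical paths: substitute a real SD3 IP-Adapter repository and file name.
+        >>> pipeline.load_ip_adapter(
+        ...     "org-name/sd3-ip-adapter",
+        ...     subfolder="ip_adapter",
+        ...     weight_name="ip-adapter.safetensors",
+        ... )
+        ```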
+        """
+        # Load the main state dict first
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            model_file = _get_model_file(
+                pretrained_model_name_or_path_or_dict,
+                weights_name=weight_name,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            if weight_name.endswith(".safetensors"):
+                state_dict = {"image_proj": {}, "ip_adapter": {}}
+                with safe_open(model_file, framework="pt", device="cpu") as f:
+                    for key in f.keys():
+                        if key.startswith("image_proj."):
+                            state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                        elif key.startswith("ip_adapter."):
+                            state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+            else:
+                state_dict = load_state_dict(model_file)
+        else:
+            state_dict = pretrained_model_name_or_path_or_dict
+
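+        # Whether read from a checkpoint file or passed in directly, the state dict is expected to carry two
+        # top-level groups: "image_proj" (image projection weights) and "ip_adapter" (attention processor weights).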
+        keys = list(state_dict.keys())
+        if "image_proj" not in keys and "ip_adapter" not in keys:
+            raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+        # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+            if image_encoder_folder is not None:
+                if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                    if image_encoder_folder.count("/") == 0:
+                        image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+                    else:
+                        image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+                    # Common args for loading the image encoder and the image processor
+                    args = dict(
+                        subfolder=image_encoder_subfolder,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        cache_dir=cache_dir,
+                        local_files_only=local_files_only,
+                    )
+
+                    self.register_modules(
+                        feature_extractor=SiglipImageProcessor.from_pretrained(
+                            pretrained_model_name_or_path_or_dict, **args
+                        ),
+                        image_encoder=SiglipVisionModel.from_pretrained(
+                            pretrained_model_name_or_path_or_dict, **args
+                        ).to(self.device, dtype=self.dtype),
+                    )
+                else:
+                    raise ValueError(
+                        "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+                    )
+            else:
+                logger.warning(
+                    "image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to"
+                    " use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+                    " Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
+                )
+
+        # Load IP-Adapter into transformer
+        self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
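+        # This swaps the transformer's attention processors for `IPAdapterJointAttnProcessor2_0` and attaches the
+        # `image_proj` module; `unload_ip_adapter` below reverses both steps.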
+
+    def set_ip_adapter_scale(self, scale: float):
+        """
+        Controls image/text prompt conditioning. A value of 1.0 means the model is conditioned only on the image
+        prompt, and 0.0 only on the text prompt. Lowering this value encourages the model to produce more diverse
+        images, but they may not be as aligned with the image prompt.
+
+        Example:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.set_ip_adapter_scale(0.6)
+        >>> ...
+        ```
+        """
+        for attn_processor in self.transformer.attn_processors.values():
+            if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0):
+                attn_processor.scale = scale
+
+    def unload_ip_adapter(self):
+        """
+        Unloads the IP Adapter weights.
+
+        Example:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.unload_ip_adapter()
+        >>> ...
+        ```
+        """
+        # Remove image encoder
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+            self.image_encoder = None
+            self.register_to_config(image_encoder=None)
+
+        # Remove feature extractor
+        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+            self.feature_extractor = None
+            self.register_to_config(feature_extractor=None)
+
+        # Remove image projection
+        self.transformer.image_proj = None
+
+        # Restore the original attention processors
+        attn_procs = {
+            name: (
+                JointAttnProcessor2_0()
+                if isinstance(value, IPAdapterJointAttnProcessor2_0)
+                else value.__class__()
+            )
+            for name, value in self.transformer.attn_processors.items()
+        }
+        self.transformer.set_attn_processor(attn_procs)