@@ -549,25 +549,41 @@ def load_ip_adapter(
549549        # load ip-adapter into transformer 
550550        self .transformer ._load_ip_adapter_weights (state_dicts , low_cpu_mem_usage = low_cpu_mem_usage )
551551
552-     def  set_ip_adapter_scale (self , scale ):
552+     def  set_ip_adapter_scale (self , scale :  Union [ float ,  List [ float ],  List [ List [ float ]]] ):
553553        """ 
554554        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for 
555-         granular control over each IP-Adapter behavior. A config can be a float or a dictionary. 
555+         granular control over each IP-Adapter behavior. A config can be a float or a list. 
556+ 
557+         `float` is converted to list and repeated for the number of blocks and the number of IP adapters. 
558+         `List[float]` length match the number of blocks, it is repeated for each IP adapter. 
559+         `List[List[float]]` must match the number of IP adapters and each must match the number of blocks. 
556560
557561        Example: 
558562
559563        ```py 
560564        # To use original IP-Adapter 
561565        scale = 1.0 
562566        pipeline.set_ip_adapter_scale(scale) 
567+         def LinearStrengthModel(start, finish, size): 
568+             return [ 
569+                 (start + (finish - start) * (i / (size - 1))) for i in range(size) 
570+             ] 
571+ 
572+         ip_strengths = LinearStrengthModel(0.3, 0.92, 19) 
573+         pipeline.set_ip_adapter_scale(ip_strengths) 
563574        ``` 
564575        """ 
565576        transformer  =  self .transformer 
566577        if  not  isinstance (scale , list ):
578+             scale  =  [[scale ] *  transformer .config .num_layers ]
579+         elif  isinstance (scale , list ) and  isinstance (scale [0 ], int ) or  isinstance (scale [0 ], float ):
580+             if  len (scale ) !=  transformer .config .num_layers :
581+                 raise  ValueError (f"Expected list of { transformer .config .num_layers } { len (scale )}  )
567582            scale  =  [scale ]
568583
569584        scale_configs  =  scale 
570585
586+         key_id  =  0 
571587        for  attn_name , attn_processor  in  transformer .attn_processors .items ():
572588            if  isinstance (attn_processor , (FluxIPAdapterJointAttnProcessor2_0 )):
573589                if  len (scale_configs ) !=  len (attn_processor .scale ):
@@ -578,7 +594,8 @@ def set_ip_adapter_scale(self, scale):
578594                elif  len (scale_configs ) ==  1 :
579595                    scale_configs  =  scale_configs  *  len (attn_processor .scale )
580596                for  i , scale_config  in  enumerate (scale_configs ):
581-                     attn_processor .scale [i ] =  scale_config 
597+                     attn_processor .scale [i ] =  scale_config [key_id ]
598+                 key_id  +=  1 
582599
583600    def  unload_ip_adapter (self ):
584601        """ 
0 commit comments