@@ -549,25 +549,41 @@ def load_ip_adapter(
549549 # load ip-adapter into transformer
550550 self .transformer ._load_ip_adapter_weights (state_dicts , low_cpu_mem_usage = low_cpu_mem_usage )
551551
552- def set_ip_adapter_scale (self , scale ):
552+ def set_ip_adapter_scale (self , scale : Union [ float , List [ float ], List [ List [ float ]]] ):
553553 """
554554 Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
555- granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
555+ granular control over each IP-Adapter behavior. A config can be a float or a list.
556+
557+ `float` is converted to list and repeated for the number of blocks and the number of IP adapters.
558+ `List[float]` length match the number of blocks, it is repeated for each IP adapter.
559+ `List[List[float]]` must match the number of IP adapters and each must match the number of blocks.
556560
557561 Example:
558562
559563 ```py
560564 # To use original IP-Adapter
561565 scale = 1.0
562566 pipeline.set_ip_adapter_scale(scale)
567+ def LinearStrengthModel(start, finish, size):
568+ return [
569+ (start + (finish - start) * (i / (size - 1))) for i in range(size)
570+ ]
571+
572+ ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
573+ pipeline.set_ip_adapter_scale(ip_strengths)
563574 ```
564575 """
565576 transformer = self .transformer
566577 if not isinstance (scale , list ):
578+ scale = [[scale ] * transformer .config .num_layers ]
579+ elif isinstance (scale , list ) and isinstance (scale [0 ], int ) or isinstance (scale [0 ], float ):
580+ if len (scale ) != transformer .config .num_layers :
581+ raise ValueError (f"Expected list of { transformer .config .num_layers } scales, got { len (scale )} ." )
567582 scale = [scale ]
568583
569584 scale_configs = scale
570585
586+ key_id = 0
571587 for attn_name , attn_processor in transformer .attn_processors .items ():
572588 if isinstance (attn_processor , (FluxIPAdapterJointAttnProcessor2_0 )):
573589 if len (scale_configs ) != len (attn_processor .scale ):
@@ -578,7 +594,8 @@ def set_ip_adapter_scale(self, scale):
578594 elif len (scale_configs ) == 1 :
579595 scale_configs = scale_configs * len (attn_processor .scale )
580596 for i , scale_config in enumerate (scale_configs ):
581- attn_processor .scale [i ] = scale_config
597+ attn_processor .scale [i ] = scale_config [key_id ]
598+ key_id += 1
582599
583600 def unload_ip_adapter (self ):
584601 """
0 commit comments