Skip to content

Commit 12833b1

Browse files
committed
v2
1 parent 84f08d7 commit 12833b1

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

src/diffusers/loaders/ip_adapter.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,25 +549,41 @@ def load_ip_adapter(
549549
# load ip-adapter into transformer
550550
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
551551

552-
def set_ip_adapter_scale(self, scale):
552+
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
553553
"""
554554
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
555-
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
555+
granular control over each IP-Adapter behavior. A config can be a float or a list.
556+
557+
`float` is converted to a list and repeated for the number of blocks and the number of IP adapters.
558+
`List[float]` length must match the number of blocks; it is repeated for each IP adapter.
559+
`List[List[float]]` must match the number of IP adapters and each must match the number of blocks.
556560
557561
Example:
558562
559563
```py
560564
# To use original IP-Adapter
561565
scale = 1.0
562566
pipeline.set_ip_adapter_scale(scale)
567+
def LinearStrengthModel(start, finish, size):
568+
return [
569+
(start + (finish - start) * (i / (size - 1))) for i in range(size)
570+
]
571+
572+
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
573+
pipeline.set_ip_adapter_scale(ip_strengths)
563574
```
564575
"""
565576
transformer = self.transformer
566577
if not isinstance(scale, list):
578+
scale = [[scale] * transformer.config.num_layers]
579+
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
580+
if len(scale) != transformer.config.num_layers:
581+
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
567582
scale = [scale]
568583

569584
scale_configs = scale
570585

586+
key_id = 0
571587
for attn_name, attn_processor in transformer.attn_processors.items():
572588
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
573589
if len(scale_configs) != len(attn_processor.scale):
@@ -578,7 +594,8 @@ def set_ip_adapter_scale(self, scale):
578594
elif len(scale_configs) == 1:
579595
scale_configs = scale_configs * len(attn_processor.scale)
580596
for i, scale_config in enumerate(scale_configs):
581-
attn_processor.scale[i] = scale_config
597+
attn_processor.scale[i] = scale_config[key_id]
598+
key_id += 1
582599

583600
def unload_ip_adapter(self):
584601
"""

src/diffusers/loaders/transformer_flux.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,11 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
6262

6363
if "proj.weight" in state_dict:
6464
# IP-Adapter
65-
# TODO: fix for XLabs-AI/flux-ip-adapter-v2
6665
num_image_text_embeds = 4
66+
if state_dict["proj.weight"].shape[0] == 65536:
67+
num_image_text_embeds = 16
6768
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
68-
cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
69+
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
6970

7071
with init_context():
7172
image_projection = ImageProjection(
@@ -124,9 +125,11 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
124125
num_image_text_embeds = []
125126
for state_dict in state_dicts:
126127
if "proj.weight" in state_dict["image_proj"]:
128+
num_image_text_embed = 4
129+
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
130+
num_image_text_embed = 16
127131
# IP-Adapter
128-
# TODO: change for XLabs-AI/flux-ip-adapter-v2
129-
num_image_text_embeds += [4]
132+
num_image_text_embeds += [num_image_text_embed]
130133

131134
with init_context():
132135
attn_procs[name] = attn_processor_class(

0 commit comments

Comments
 (0)