
Commit 2ee946f

Flux IP-Adapter
1 parent f9d5a93 commit 2ee946f

8 files changed: 785 additions, 11 deletions

scripts/convert_flux_xlabs_ipadapter_to_diffusers.py

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
import argparse
from contextlib import nullcontext

import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available


if is_transformers_available():
    from transformers import CLIPVisionModelWithProjection

    vision = True
else:
    vision = False

"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors" \
--output_path "flux-ip-adapter-hf/"
"""


CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()


def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
    converted_state_dict = {}

    # image_proj
    ## norm
    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
    ## proj
    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"ip_adapter.{i}."
        # to_k_ip
        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
        )
        # to_v_ip
        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
        )

    return converted_state_dict


def main(args):
    original_ckpt = load_original_checkpoint(args)

    # Flux uses 19 double-stream transformer blocks.
    num_layers = 19
    converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)

    print("Saving Flux IP-Adapter in Diffusers format.")
    safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")

    if vision:
        model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
        model.save_pretrained(f"{args.output_path}/image_encoder")


if __name__ == "__main__":
    main(args)
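
As a sanity check, here is a minimal inspection sketch for the converted checkpoint. It assumes the "flux-ip-adapter-hf/" output path from the usage string above and relies only on the key layout that convert_flux_ipadapter_checkpoint_to_diffusers writes:

# Sketch: verify the key layout of the converted checkpoint produced by the script above.
# Assumes the "flux-ip-adapter-hf/" output path from the usage example.
import safetensors.torch

converted = safetensors.torch.load_file("flux-ip-adapter-hf/model.safetensors")

# Image-projection keys written by the conversion function.
proj_keys = {
    "image_proj.norm.weight",
    "image_proj.norm.bias",
    "image_proj.proj.weight",
    "image_proj.proj.bias",
}
assert proj_keys.issubset(converted.keys())

# One to_k_ip / to_v_ip pair (weight and bias) per double-stream block, 19 blocks in total.
for i in range(19):
    for name in ("to_k_ip", "to_v_ip"):
        assert f"ip_adapter.{i}.{name}.weight" in converted
        assert f"ip_adapter.{i}.{name}.bias" in converted

print(f"{len(converted)} tensors, key layout looks correct")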

src/diffusers/loaders/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):
 
 if is_torch_available():
     _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
-
+    _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
     _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]
     if is_transformers_available():
@@ -72,19 +72,20 @@ def text_encoder_attn_modules(text_encoder):
             "Mochi1LoraLoaderMixin",
         ]
         _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
-        _import_structure["ip_adapter"] = ["IPAdapterMixin"]
+        _import_structure["ip_adapter"] = ["IPAdapterMixin", "FluxIPAdapterMixin"]
 
 _import_structure["peft"] = ["PeftAdapterMixin"]
 
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .single_file_model import FromOriginalModelMixin
+        from .transformer_flux import FluxTransformer2DLoadersMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
 
         if is_transformers_available():
-            from .ip_adapter import IPAdapterMixin
+            from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin
             from .lora_pipeline import (
                 AmusedLoraLoaderMixin,
                 CogVideoXLoraLoaderMixin,
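
With these entries registered, the new loader mixins resolve through diffusers' lazy import structure. A minimal usage sketch, assuming a diffusers build that contains this commit with torch and transformers installed:

# Sketch: the mixins registered in _import_structure above become importable from
# diffusers.loaders (FluxIPAdapterMixin from .ip_adapter, FluxTransformer2DLoadersMixin
# from .transformer_flux).
from diffusers.loaders import FluxIPAdapterMixin, FluxTransformer2DLoadersMixin

print(FluxIPAdapterMixin.__name__, FluxTransformer2DLoadersMixin.__name__)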
