Skip to content

Commit 7ed7141

Browse files
authored
Merge branch 'main' into layerwise-upcasting-hook
2 parents 7dc739b + 233dffd commit 7ed7141

File tree

15 files changed

+1165
-16
lines changed

15 files changed

+1165
-16
lines changed

.github/workflows/push_tests_mps.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
shell: arch -arch arm64 bash {0}
4747
run: |
4848
${CONDA_RUN} python -m pip install --upgrade pip uv
49-
${CONDA_RUN} python -m uv pip install -e [quality,test]
49+
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
5050
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
5151
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
5252
${CONDA_RUN} python -m uv pip install transformers --upgrade
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import argparse
2+
from contextlib import nullcontext
3+
4+
import safetensors.torch
5+
from accelerate import init_empty_weights
6+
from huggingface_hub import hf_hub_download
7+
8+
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
9+
10+
11+
if is_transformers_available():
12+
from transformers import CLIPVisionModelWithProjection
13+
14+
vision = True
15+
else:
16+
vision = False
17+
18+
"""
19+
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
20+
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
21+
--filename "flux-ip-adapter.safetensors"
22+
--output_path "flux-ip-adapter-hf/"
23+
"""
24+
25+
26+
CTX = init_empty_weights if is_accelerate_available else nullcontext
27+
28+
parser = argparse.ArgumentParser()
29+
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
30+
parser.add_argument("--filename", default="flux.safetensors", type=str)
31+
parser.add_argument("--checkpoint_path", default=None, type=str)
32+
parser.add_argument("--output_path", type=str)
33+
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
34+
35+
args = parser.parse_args()
36+
37+
38+
def load_original_checkpoint(args):
39+
if args.original_state_dict_repo_id is not None:
40+
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
41+
elif args.checkpoint_path is not None:
42+
ckpt_path = args.checkpoint_path
43+
else:
44+
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
45+
46+
original_state_dict = safetensors.torch.load_file(ckpt_path)
47+
return original_state_dict
48+
49+
50+
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
51+
converted_state_dict = {}
52+
53+
# image_proj
54+
## norm
55+
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
56+
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
57+
## proj
58+
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
59+
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
60+
61+
# double transformer blocks
62+
for i in range(num_layers):
63+
block_prefix = f"ip_adapter.{i}."
64+
# to_k_ip
65+
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
66+
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
67+
)
68+
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
69+
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
70+
)
71+
# to_v_ip
72+
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
73+
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
74+
)
75+
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
76+
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
77+
)
78+
79+
return converted_state_dict
80+
81+
82+
def main(args):
83+
original_ckpt = load_original_checkpoint(args)
84+
85+
num_layers = 19
86+
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
87+
88+
print("Saving Flux IP-Adapter in Diffusers format.")
89+
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
90+
91+
if vision:
92+
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
93+
model.save_pretrained(f"{args.output_path}/image_encoder")
94+
95+
96+
if __name__ == "__main__":
97+
main(args)

src/diffusers/loaders/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):
5555

5656
if is_torch_available():
5757
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
58-
58+
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
5959
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
6060
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
6161
_import_structure["utils"] = ["AttnProcsLayers"]
@@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder):
7777
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
7878
_import_structure["ip_adapter"] = [
7979
"IPAdapterMixin",
80+
"FluxIPAdapterMixin",
8081
"SD3IPAdapterMixin",
8182
]
8283

@@ -86,12 +87,14 @@ def text_encoder_attn_modules(text_encoder):
8687
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
8788
if is_torch_available():
8889
from .single_file_model import FromOriginalModelMixin
90+
from .transformer_flux import FluxTransformer2DLoadersMixin
8991
from .transformer_sd3 import SD3Transformer2DLoadersMixin
9092
from .unet import UNet2DConditionLoadersMixin
9193
from .utils import AttnProcsLayers
9294

9395
if is_transformers_available():
9496
from .ip_adapter import (
97+
FluxIPAdapterMixin,
9598
IPAdapterMixin,
9699
SD3IPAdapterMixin,
97100
)

0 commit comments

Comments
 (0)