Commit 6a6636c

Flux IP-Adapter
1 parent c5376c5 commit 6a6636c

7 files changed: +706 -7 lines changed
scripts/convert_flux_xlabs_ipadapter_to_diffusers.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+import argparse
+from contextlib import nullcontext
+
+import safetensors.torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+
+from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
+
+
+if is_transformers_available():
+    from transformers import CLIPVisionModelWithProjection
+
+    vision = True
+else:
+    vision = False
+
+"""
+python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
+--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
+--filename "flux-ip-adapter.safetensors" \
+--output_path "flux-ip-adapter-hf/"
+"""
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--filename", default="flux.safetensors", type=str)
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
+
+args = parser.parse_args()
+
+
+def load_original_checkpoint(args):
+    if args.original_state_dict_repo_id is not None:
+        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+    elif args.checkpoint_path is not None:
+        ckpt_path = args.checkpoint_path
+    else:
+        raise ValueError("please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+    original_state_dict = safetensors.torch.load_file(ckpt_path)
+    return original_state_dict
+
+
+def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
+    converted_state_dict = {}
+
+    # image_proj
+    ## norm
+    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
+    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
+    ## proj
+    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
+    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"ip_adapter.{i}."
+        # to_k_ip
+        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
+            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
+        )
+        # to_v_ip
+        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
+            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
+        )
+
+    return converted_state_dict
+
+
+def main(args):
+    original_ckpt = load_original_checkpoint(args)
+
+    num_layers = 19
+    converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
+
+    print("Saving Flux IP-Adapter in Diffusers format.")
+    safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
+
+    if vision:
+        model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
+        model.save_pretrained(f"{args.output_path}/image_encoder")
+
+
+if __name__ == "__main__":
+    main(args)
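
As a usage sketch (illustrative, not taken from the diff): after the command in the script's docstring is run, the output folder holds the converted model.safetensors next to an image_encoder/ directory, which the FluxIPAdapterMixin added below in src/diffusers/loaders/ip_adapter.py can consume. A minimal sketch, assuming a Flux pipeline class that inherits the new mixin and local paths matching the docstring example:

import torch

from diffusers import FluxPipeline

# "flux-ip-adapter-hf" is a hypothetical local folder, assuming the conversion script was
# run with --output_path "flux-ip-adapter-hf/" as shown in the docstring above.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "flux-ip-adapter-hf",
    subfolder=None,  # no subfolder in the converted output; the parameter has no default in this version
    weight_name="model.safetensors",
    image_encoder_pretrained_model_name_or_path="flux-ip-adapter-hf/image_encoder",
)
pipe.set_ip_adapter_scale(1.0)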

src/diffusers/loaders/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):

 if is_torch_available():
     _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
-
+    _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
     _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]
     if is_transformers_available():
@@ -70,19 +70,20 @@ def text_encoder_attn_modules(text_encoder):
             "CogVideoXLoraLoaderMixin",
         ]
         _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
-        _import_structure["ip_adapter"] = ["IPAdapterMixin"]
+        _import_structure["ip_adapter"] = ["IPAdapterMixin", "FluxIPAdapterMixin"]

 _import_structure["peft"] = ["PeftAdapterMixin"]


 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .single_file_model import FromOriginalModelMixin
+        from .transformer_flux import FluxTransformer2DLoadersMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers

     if is_transformers_available():
-        from .ip_adapter import IPAdapterMixin
+        from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin
         from .lora_pipeline import (
             AmusedLoraLoaderMixin,
             CogVideoXLoraLoaderMixin,
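
As a quick illustration (not from the diff), the registration above makes the new loader classes resolve through the package's existing lazy-import machinery, so the following import is expected to work with this commit applied:

from diffusers.loaders import FluxIPAdapterMixin, FluxTransformer2DLoadersMixin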

src/diffusers/loaders/ip_adapter.py

Lines changed: 250 additions & 0 deletions
@@ -41,6 +41,8 @@
 from ..models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
+    FluxAttnProcessor2_0,
+    FluxIPAdapterAttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
 )
@@ -346,3 +348,251 @@ def unload_ip_adapter(self):
                 else value.__class__()
             )
         self.unet.set_attn_processor(attn_procs)
+
+
+class FluxIPAdapterMixin:
+    """Mixin for handling Flux IP Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+        subfolder: Union[str, List[str]],
+        weight_name: Union[str, List[str]],
+        image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
+        image_encoder_subfolder: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `"image_encoder"`):
+                Can be either:
+
+                    - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+        """
+
+        # handle the list inputs for multiple IP Adapters
+        if not isinstance(weight_name, list):
+            weight_name = [weight_name]
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, list):
+            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+        if len(pretrained_model_name_or_path_or_dict) == 1:
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+        if not isinstance(subfolder, list):
+            subfolder = [subfolder]
+        if len(subfolder) == 1:
+            subfolder = subfolder * len(weight_name)
+
+        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+        if len(weight_name) != len(subfolder):
+            raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+        # Load the main state dict first.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+        state_dicts = []
+        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+            pretrained_model_name_or_path_or_dict, weight_name, subfolder
+        ):
+            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                if weight_name.endswith(".safetensors"):
+                    state_dict = {"image_proj": {}, "ip_adapter": {}}
+                    with safe_open(model_file, framework="pt", device="cpu") as f:
+                        image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
+                        ip_adapter_keys = ["double_blocks.", "ip_adapter."]
+                        for key in f.keys():
+                            if any(key.startswith(prefix) for prefix in image_proj_keys):
+                                diffusers_name = ".".join(key.split(".")[1:])
+                                state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
+                            elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
+                                diffusers_name = (
+                                    ".".join(key.split(".")[1:])
+                                    .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
+                                    .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
+                                )
+                                state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
+                else:
+                    state_dict = load_state_dict(model_file)
+            else:
+                state_dict = pretrained_model_name_or_path_or_dict
+
+            keys = list(state_dict.keys())
+            if keys != ["image_proj", "ip_adapter"]:
+                raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+            state_dicts.append(state_dict)
+
+            # load CLIP image encoder here if it has not been registered to the pipeline yet
+            if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+                if image_encoder_pretrained_model_name_or_path is not None:
+                    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                        logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
+                        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                            image_encoder_pretrained_model_name_or_path,
+                            subfolder=image_encoder_subfolder,
+                            low_cpu_mem_usage=low_cpu_mem_usage,
+                            cache_dir=cache_dir,
+                            local_files_only=local_files_only,
+                        ).to(self.device, dtype=self.dtype)
+                        self.register_modules(image_encoder=image_encoder)
+                    else:
+                        raise ValueError(
+                            "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+                        )
+                else:
+                    logger.warning(
+                        "image_encoder is not loaded since `image_encoder_pretrained_model_name_or_path=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+                        " Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
+                    )
+
+            # create feature extractor if it has not been registered to the pipeline yet
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+                # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
+                default_clip_size = 224
+                clip_image_size = (
+                    self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
+                )
+                feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
+                self.register_modules(feature_extractor=feature_extractor)
+
+        # load ip-adapter into transformer
+        self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+    def set_ip_adapter_scale(self, scale):
+        """
+        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
+
+        Example:
+
+        ```py
+        # To use original IP-Adapter
+        scale = 1.0
+        pipeline.set_ip_adapter_scale(scale)
+        ```
+        """
+        transformer = self.transformer
+
+        for attn_name, attn_processor in transformer.attn_processors.items():
+            if isinstance(attn_processor, (FluxIPAdapterAttnProcessor2_0)):
+                attn_processor.scale = scale
+
+    def unload_ip_adapter(self):
+        """
+        Unloads the IP Adapter weights
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.unload_ip_adapter()
+        >>> ...
+        ```
+        """
+        # remove CLIP image encoder
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+            self.image_encoder = None
+            self.register_to_config(image_encoder=[None, None])
+
+        # remove feature extractor only when safety_checker is None as safety_checker uses
+        # the feature_extractor later
+        if not hasattr(self, "safety_checker"):
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+                self.feature_extractor = None
+                self.register_to_config(feature_extractor=[None, None])
+
+        # remove hidden encoder
+        self.transformer.encoder_hid_proj = None
+        self.transformer.config.encoder_hid_dim_type = None
+
+        # restore original Transformer attention processors layers
+        attn_procs = {}
+        for name, value in self.transformer.attn_processors.items():
+            attn_processor_class = FluxAttnProcessor2_0()
+            attn_procs[name] = (
+                attn_processor_class if isinstance(value, (FluxIPAdapterAttnProcessor2_0)) else value.__class__()
+            )
+        self.transformer.set_attn_processor(attn_procs)
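
As a usage note (illustrative, not from the diff): the rename chain in load_ip_adapter above is what maps XLabs checkpoint keys onto the to_k_ip / to_v_ip names expected by FluxIPAdapterAttnProcessor2_0. A minimal sketch of the same transformation applied to one example key:

# Mirrors the renaming logic from load_ip_adapter above (illustrative snippet).
key = "double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"
diffusers_name = (
    ".".join(key.split(".")[1:])
    .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
    .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
)
print(diffusers_name)  # 0.processor.to_k_ip.weight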
