
Commit 42fe3fb

Flux IP-Adapter
1 parent c5376c5 commit 42fe3fb

7 files changed: +730 -7 lines changed
scripts/convert_flux_xlabs_ipadapter_to_diffusers.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import argparse
from contextlib import nullcontext

import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available


if is_transformers_available():
    from transformers import CLIPVisionModelWithProjection

    vision = True
else:
    vision = False

"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors" \
--output_path "flux-ip-adapter-hf/"
"""


CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()


def load_original_checkpoint(args):
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`.")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
    # Maps the XLabs key layout (ip_adapter_proj_model.*, double_blocks.*) to the
    # diffusers layout (image_proj.*, ip_adapter.*) that the loader below expects.
    converted_state_dict = {}

    # image_proj
    ## norm
    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
    ## proj
    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"ip_adapter.{i}."
        # to_k_ip
        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
        )
        # to_v_ip
        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
        )

    return converted_state_dict


def main(args):
    original_ckpt = load_original_checkpoint(args)

    num_layers = 19  # number of double-stream transformer blocks in Flux
    converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)

    print("Saving Flux IP-Adapter in Diffusers format.")
    safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")

    if vision:
        model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
        model.save_pretrained(f"{args.output_path}/image_encoder")


if __name__ == "__main__":
    main(args)

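For reference, these are the representative key renames the converter performs, derived from `convert_flux_ipadapter_checkpoint_to_diffusers` above (the `0` stands for any of the 19 double-stream block indices):

```py
# Representative key renames applied by the conversion script above.
key_map = {
    "ip_adapter_proj_model.norm.weight": "image_proj.norm.weight",
    "ip_adapter_proj_model.proj.weight": "image_proj.proj.weight",
    "double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": "ip_adapter.0.to_k_ip.weight",
    "double_blocks.0.processor.ip_adapter_double_stream_v_proj.weight": "ip_adapter.0.to_v_ip.weight",
}
```

The same renaming also appears in the loader added to `src/diffusers/loaders/ip_adapter.py` below, so un-converted XLabs checkpoints can be read directly.
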
src/diffusers/loaders/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):

if is_torch_available():
    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
-
+    _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
    _import_structure["utils"] = ["AttnProcsLayers"]
    if is_transformers_available():
@@ -70,19 +70,20 @@ def text_encoder_attn_modules(text_encoder):
            "CogVideoXLoraLoaderMixin",
        ]
        _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
-        _import_structure["ip_adapter"] = ["IPAdapterMixin"]
+        _import_structure["ip_adapter"] = ["IPAdapterMixin", "FluxIPAdapterMixin"]

_import_structure["peft"] = ["PeftAdapterMixin"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from .single_file_model import FromOriginalModelMixin
+        from .transformer_flux import FluxTransformer2DLoadersMixin
        from .unet import UNet2DConditionLoadersMixin
        from .utils import AttnProcsLayers

        if is_transformers_available():
-            from .ip_adapter import IPAdapterMixin
+            from .ip_adapter import IPAdapterMixin, FluxIPAdapterMixin
            from .lora_pipeline import (
                AmusedLoraLoaderMixin,
                CogVideoXLoraLoaderMixin,

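With these registrations in place, both new mixins resolve through the loaders package's lazy-import machinery. A small sketch, assuming an environment built from this commit:

```py
from diffusers.loaders import FluxIPAdapterMixin, FluxTransformer2DLoadersMixin
```
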
src/diffusers/loaders/ip_adapter.py

Lines changed: 253 additions & 0 deletions
@@ -43,6 +43,8 @@
    AttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
+    FluxIPAdapterAttnProcessor2_0,
+    FluxAttnProcessor2_0,
)

logger = logging.get_logger(__name__)
@@ -346,3 +348,254 @@ def unload_ip_adapter(self):
                else value.__class__()
            )
        self.unet.set_attn_processor(attn_procs)
+
+
+class FluxIPAdapterMixin:
+    """Mixin for handling Flux IP Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+        subfolder: Union[str, List[str]],
+        weight_name: Union[str, List[str]],
+        image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
+        image_encoder_subfolder: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `"image_encoder"`):
+                Can be either:
+
+                    - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
+                      hosted on the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+        """
+
+        # handle the list inputs for multiple IP Adapters
+        if not isinstance(weight_name, list):
+            weight_name = [weight_name]
+
+        if not isinstance(pretrained_model_name_or_path_or_dict, list):
+            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+        if len(pretrained_model_name_or_path_or_dict) == 1:
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+        if not isinstance(subfolder, list):
+            subfolder = [subfolder]
+        if len(subfolder) == 1:
+            subfolder = subfolder * len(weight_name)
+
+        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+        if len(weight_name) != len(subfolder):
+            raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+        # Load the main state dict first.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+        state_dicts = []
+        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+            pretrained_model_name_or_path_or_dict, weight_name, subfolder
+        ):
+            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                if weight_name.endswith(".safetensors"):
+                    state_dict = {"image_proj": {}, "ip_adapter": {}}
+                    with safe_open(model_file, framework="pt", device="cpu") as f:
+                        image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
+                        ip_adapter_keys = ["double_blocks.", "ip_adapter."]
+                        for key in f.keys():
+                            if any([key.startswith(prefix) for prefix in image_proj_keys]):
+                                diffusers_name = ".".join(key.split(".")[1:])
+                                state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
+                            elif any([key.startswith(prefix) for prefix in ip_adapter_keys]):
+                                diffusers_name = (
+                                    ".".join(key.split(".")[1:])
+                                    .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
+                                    .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
+                                )
+                                state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
+                else:
+                    state_dict = load_state_dict(model_file)
+            else:
+                state_dict = pretrained_model_name_or_path_or_dict
+
+            keys = list(state_dict.keys())
+            if keys != ["image_proj", "ip_adapter"]:
+                raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
+
+            state_dicts.append(state_dict)
+
+        # load CLIP image encoder here if it has not been registered to the pipeline yet
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+            if image_encoder_pretrained_model_name_or_path is not None:
+                if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                    logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
+                    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                        image_encoder_pretrained_model_name_or_path,
+                        subfolder=image_encoder_subfolder,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        cache_dir=cache_dir,
+                        local_files_only=local_files_only,
+                    ).to(self.device, dtype=self.dtype)
+                    self.register_modules(image_encoder=image_encoder)
+                else:
+                    raise ValueError(
+                        "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+                    )
+            else:
+                logger.warning(
+                    "image_encoder is not loaded since `image_encoder_pretrained_model_name_or_path=None` was passed."
+                    " You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+                    " Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
+                )
+
+        # create feature extractor if it has not been registered to the pipeline yet
+        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+            # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
+            default_clip_size = 224
+            clip_image_size = (
+                self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
+            )
+            feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
+            self.register_modules(feature_extractor=feature_extractor)
+
+        # load ip-adapter into transformer
+        self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+    def set_ip_adapter_scale(self, scale):
+        """
+        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
+
+        Example:
+
+        ```py
+        # To use original IP-Adapter
+        scale = 1.0
+        pipeline.set_ip_adapter_scale(scale)
+        ```
+        """
+        transformer = self.transformer
+
+        for attn_name, attn_processor in transformer.attn_processors.items():
+            if isinstance(attn_processor, (FluxIPAdapterAttnProcessor2_0)):
+                attn_processor.scale = scale
+
+    def unload_ip_adapter(self):
+        """
+        Unloads the IP Adapter weights.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+        >>> pipeline.unload_ip_adapter()
+        >>> ...
+        ```
+        """
+        # remove CLIP image encoder
+        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+            self.image_encoder = None
+            self.register_to_config(image_encoder=[None, None])
+
+        # remove feature extractor only when safety_checker is None as safety_checker uses
+        # the feature_extractor later
+        if not hasattr(self, "safety_checker"):
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+                self.feature_extractor = None
+                self.register_to_config(feature_extractor=[None, None])
+
+        # remove hidden encoder
+        self.transformer.encoder_hid_proj = None
+        self.transformer.config.encoder_hid_dim_type = None
+
+        # restore original Transformer attention processors layers
+        attn_procs = {}
+        for name, value in self.transformer.attn_processors.items():
+            attn_processor_class = FluxAttnProcessor2_0()
+            attn_procs[name] = (
+                attn_processor_class
+                if isinstance(value, (FluxIPAdapterAttnProcessor2_0))
+                else value.__class__()
+            )
+        self.transformer.set_attn_processor(attn_procs)

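Taken together, `load_ip_adapter`, `set_ip_adapter_scale`, and `unload_ip_adapter` give Flux pipelines the same IP-Adapter lifecycle that `IPAdapterMixin` provides for UNet pipelines. A minimal end-to-end sketch follows; it assumes the Flux pipeline class mixes in `FluxIPAdapterMixin` and accepts `ip_adapter_image` (wired up in files of this commit not shown here), and that the checkpoint folder follows the layout produced by the conversion script above. Paths and the reference image are illustrative.

```py
import torch

from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Assumed folder layout from the conversion script: model.safetensors + image_encoder/
pipe.load_ip_adapter(
    "flux-ip-adapter-hf",
    subfolder="",
    weight_name="model.safetensors",
    image_encoder_pretrained_model_name_or_path="flux-ip-adapter-hf/image_encoder",
)
pipe.set_ip_adapter_scale(1.0)  # single float config, as in the docstring example

image = pipe(
    prompt="a cat wearing sunglasses",
    ip_adapter_image=load_image("style_reference.png"),  # hypothetical reference image
    num_inference_steps=28,
).images[0]

pipe.unload_ip_adapter()  # restores FluxAttnProcessor2_0 and drops the image encoder / feature extractor
```
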
0 commit comments