Skip to content

Commit 68a5185

Browse files
committed
refactor more, ipadapter node, lora node
1 parent 6e2fe26 commit 68a5185

File tree

4 files changed

+1047
-857
lines changed

4 files changed

+1047
-857
lines changed

src/diffusers/guider.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def prepare_input(
169169
else:
170170
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
171171

172-
if not self._is_prepared_input(cond_input) and negative_cond_input is None:
172+
if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
173173
raise ValueError(
174174
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
175175
)
@@ -447,7 +447,7 @@ def prepare_input(
447447
else:
448448
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
449449

450-
if not self._is_prepared_input(cond_input) and negative_cond_input is None:
450+
if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None:
451451
raise ValueError(
452452
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
453453
)
@@ -688,7 +688,7 @@ def prepare_input(
688688
else:
689689
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
690690

691-
if not self._is_prepared_input(cond_input) and negative_cond_input is None:
691+
if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
692692
raise ValueError(
693693
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
694694
)

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def text_encoder_attn_modules(text_encoder):
7979
"IPAdapterMixin",
8080
"FluxIPAdapterMixin",
8181
"SD3IPAdapterMixin",
82+
"ModularIPAdapterMixin",
8283
]
8384

8485
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -97,6 +98,7 @@ def text_encoder_attn_modules(text_encoder):
9798
FluxIPAdapterMixin,
9899
IPAdapterMixin,
99100
SD3IPAdapterMixin,
101+
ModularIPAdapterMixin,
100102
)
101103
from .lora_pipeline import (
102104
AmusedLoraLoaderMixin,

src/diffusers/loaders/ip_adapter.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,262 @@ def unload_ip_adapter(self):
354354
)
355355
self.unet.set_attn_processor(attn_procs)
356356

357+
class ModularIPAdapterMixin:
358+
"""Mixin for handling IP Adapters."""
359+
360+
@validate_hf_hub_args
361+
def load_ip_adapter(
362+
self,
363+
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
364+
subfolder: Union[str, List[str]],
365+
weight_name: Union[str, List[str]],
366+
**kwargs,
367+
):
368+
"""
369+
Parameters:
370+
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
371+
Can be either:
372+
373+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
374+
the Hub.
375+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
376+
with [`ModelMixin.save_pretrained`].
377+
- A [torch state
378+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
379+
subfolder (`str` or `List[str]`):
380+
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
381+
list is passed, it should have the same length as `weight_name`.
382+
weight_name (`str` or `List[str]`):
383+
The name of the weight file to load. If a list is passed, it should have the same length as
384+
`subfolder`.
385+
cache_dir (`Union[str, os.PathLike]`, *optional*):
386+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
387+
is not used.
388+
force_download (`bool`, *optional*, defaults to `False`):
389+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
390+
cached versions if they exist.
391+
392+
proxies (`Dict[str, str]`, *optional*):
393+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
394+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
395+
local_files_only (`bool`, *optional*, defaults to `False`):
396+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
397+
won't be downloaded from the Hub.
398+
token (`str` or *bool*, *optional*):
399+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
400+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
401+
revision (`str`, *optional*, defaults to `"main"`):
402+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
403+
allowed by Git.
404+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
405+
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
406+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
407+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
408+
argument to `True` will raise an error.
409+
"""
410+
411+
# handle the list inputs for multiple IP Adapters
412+
if not isinstance(weight_name, list):
413+
weight_name = [weight_name]
414+
415+
if not isinstance(pretrained_model_name_or_path_or_dict, list):
416+
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
417+
if len(pretrained_model_name_or_path_or_dict) == 1:
418+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
419+
420+
if not isinstance(subfolder, list):
421+
subfolder = [subfolder]
422+
if len(subfolder) == 1:
423+
subfolder = subfolder * len(weight_name)
424+
425+
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
426+
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
427+
428+
if len(weight_name) != len(subfolder):
429+
raise ValueError("`weight_name` and `subfolder` must have the same length.")
430+
431+
# Load the main state dict first.
432+
cache_dir = kwargs.pop("cache_dir", None)
433+
force_download = kwargs.pop("force_download", False)
434+
proxies = kwargs.pop("proxies", None)
435+
local_files_only = kwargs.pop("local_files_only", None)
436+
token = kwargs.pop("token", None)
437+
revision = kwargs.pop("revision", None)
438+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
439+
440+
if low_cpu_mem_usage and not is_accelerate_available():
441+
low_cpu_mem_usage = False
442+
logger.warning(
443+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
444+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
445+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
446+
" install accelerate\n```\n."
447+
)
448+
449+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
450+
raise NotImplementedError(
451+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
452+
" `low_cpu_mem_usage=False`."
453+
)
454+
455+
user_agent = {
456+
"file_type": "attn_procs_weights",
457+
"framework": "pytorch",
458+
}
459+
state_dicts = []
460+
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
461+
pretrained_model_name_or_path_or_dict, weight_name, subfolder
462+
):
463+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
464+
model_file = _get_model_file(
465+
pretrained_model_name_or_path_or_dict,
466+
weights_name=weight_name,
467+
cache_dir=cache_dir,
468+
force_download=force_download,
469+
proxies=proxies,
470+
local_files_only=local_files_only,
471+
token=token,
472+
revision=revision,
473+
subfolder=subfolder,
474+
user_agent=user_agent,
475+
)
476+
if weight_name.endswith(".safetensors"):
477+
state_dict = {"image_proj": {}, "ip_adapter": {}}
478+
with safe_open(model_file, framework="pt", device="cpu") as f:
479+
for key in f.keys():
480+
if key.startswith("image_proj."):
481+
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
482+
elif key.startswith("ip_adapter."):
483+
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
484+
else:
485+
state_dict = load_state_dict(model_file)
486+
else:
487+
state_dict = pretrained_model_name_or_path_or_dict
488+
489+
keys = list(state_dict.keys())
490+
if "image_proj" not in keys and "ip_adapter" not in keys:
491+
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
492+
493+
state_dicts.append(state_dict)
494+
495+
# create feature extractor if it has not been registered to the pipeline yet
496+
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
497+
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
498+
default_clip_size = 224
499+
clip_image_size = (
500+
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
501+
)
502+
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
503+
504+
unet_name = getattr(self, "unet_name", "unet")
505+
unet = getattr(self, unet_name)
506+
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
507+
508+
extra_loras = unet._load_ip_adapter_loras(state_dicts)
509+
if extra_loras != {}:
510+
if not USE_PEFT_BACKEND:
511+
logger.warning("PEFT backend is required to load these weights.")
512+
else:
513+
# apply the IP Adapter Face ID LoRA weights
514+
peft_config = getattr(unet, "peft_config", {})
515+
for k, lora in extra_loras.items():
516+
if f"faceid_{k}" not in peft_config:
517+
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
518+
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
519+
520+
def set_ip_adapter_scale(self, scale):
521+
"""
522+
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
523+
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
524+
525+
Example:
526+
527+
```py
528+
# To use original IP-Adapter
529+
scale = 1.0
530+
pipeline.set_ip_adapter_scale(scale)
531+
532+
# To use style block only
533+
scale = {
534+
"up": {"block_0": [0.0, 1.0, 0.0]},
535+
}
536+
pipeline.set_ip_adapter_scale(scale)
537+
538+
# To use style+layout blocks
539+
scale = {
540+
"down": {"block_2": [0.0, 1.0]},
541+
"up": {"block_0": [0.0, 1.0, 0.0]},
542+
}
543+
pipeline.set_ip_adapter_scale(scale)
544+
545+
# To use style and layout from 2 reference images
546+
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
547+
pipeline.set_ip_adapter_scale(scales)
548+
```
549+
"""
550+
unet_name = getattr(self, "unet_name", "unet")
551+
unet = getattr(self, unet_name)
552+
if not isinstance(scale, list):
553+
scale = [scale]
554+
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
555+
556+
for attn_name, attn_processor in unet.attn_processors.items():
557+
if isinstance(
558+
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
559+
):
560+
if len(scale_configs) != len(attn_processor.scale):
561+
raise ValueError(
562+
f"Cannot assign {len(scale_configs)} scale_configs to "
563+
f"{len(attn_processor.scale)} IP-Adapter."
564+
)
565+
elif len(scale_configs) == 1:
566+
scale_configs = scale_configs * len(attn_processor.scale)
567+
for i, scale_config in enumerate(scale_configs):
568+
if isinstance(scale_config, dict):
569+
for k, s in scale_config.items():
570+
if attn_name.startswith(k):
571+
attn_processor.scale[i] = s
572+
else:
573+
attn_processor.scale[i] = scale_config
574+
575+
def unload_ip_adapter(self):
576+
"""
577+
Unloads the IP Adapter weights
578+
579+
Examples:
580+
581+
```python
582+
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
583+
>>> pipeline.unload_ip_adapter()
584+
>>> ...
585+
```
586+
"""
587+
588+
# remove hidden encoder
589+
self.unet.encoder_hid_proj = None
590+
self.unet.config.encoder_hid_dim_type = None
591+
592+
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
593+
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
594+
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
595+
self.unet.text_encoder_hid_proj = None
596+
self.unet.config.encoder_hid_dim_type = "text_proj"
597+
598+
# restore original Unet attention processors layers
599+
attn_procs = {}
600+
for name, value in self.unet.attn_processors.items():
601+
attn_processor_class = (
602+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
603+
)
604+
attn_procs[name] = (
605+
attn_processor_class
606+
if isinstance(
607+
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
608+
)
609+
else value.__class__()
610+
)
611+
self.unet.set_attn_processor(attn_procs)
612+
357613

358614
class FluxIPAdapterMixin:
359615
"""Mixin for handling Flux IP Adapters."""

0 commit comments

Comments (0)