From a71cfce598853cad0b6bff69c0e1a4fca00e432f Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Sat, 22 Feb 2025 00:04:53 +0000 Subject: [PATCH 1/9] Initial implementation of Flux multi IP-Adapter --- .../pipeline_flux_semantic_guidance.py | 4 +- src/diffusers/loaders/ip_adapter.py | 49 +++++++----- src/diffusers/models/attention_processor.py | 15 ++-- src/diffusers/models/embeddings.py | 5 ++ src/diffusers/pipelines/flux/pipeline_flux.py | 17 ++++- .../pipelines/pipeline_loading_utils.py | 75 ++++++++++++++++++- 6 files changed, 131 insertions(+), 34 deletions(-) diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py index 919e0ad46bd1..972d8c73f7d6 100644 --- a/examples/community/pipeline_flux_semantic_guidance.py +++ b/examples/community/pipeline_flux_semantic_guidance.py @@ -537,9 +537,9 @@ def prepare_ip_adapter_image_embeds( if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] - if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." ) for single_ip_adapter_image, image_proj_layer in zip( diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 7b691d1fe16e..224f6374fd01 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -577,29 +577,38 @@ def LinearStrengthModel(start, finish, size): pipeline.set_ip_adapter_scale(ip_strengths) ``` """ - transformer = self.transformer - if not isinstance(scale, list): - scale = [[scale] * transformer.config.num_layers] - elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float): - if len(scale) != transformer.config.num_layers: - raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.") + + from ..pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type + + scale_type = Union[int, float] + num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters + num_layers = self.transformer.config.num_layers + + # Single value for all layers of all IP-Adapters + if isinstance(scale, scale_type): + scale = [scale for _ in range(num_ip_adapters)] + # List of per-layer scales for a single IP-Adapter + elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1: scale = [scale] + # Invalid scale type + elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]): + raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.") - scale_configs = scale + if len(scale) != num_ip_adapters: + raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.") - key_id = 0 - for attn_name, attn_processor in transformer.attn_processors.items(): - if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)): - if len(scale_configs) != len(attn_processor.scale): - raise ValueError( - f"Cannot assign {len(scale_configs)} scale_configs to " - f"{len(attn_processor.scale)} IP-Adapter." - ) - elif len(scale_configs) == 1: - scale_configs = scale_configs * len(attn_processor.scale) - for i, scale_config in enumerate(scale_configs): - attn_processor.scale[i] = scale_config[key_id] - key_id += 1 + if any(len(s) != num_layers for s in scale if isinstance(s, list)): + invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers} + raise ValueError( + f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}." + ) + + # Scalars are transformed to lists with length num_layers + scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale] + + # Set scales. zip over scale_configs prevents going into single transformer layers + for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs): + attn_processor.scale = scale def unload_ip_adapter(self): """ diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8bba5a82bc2f..1a23eea4bb58 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2778,9 +2778,8 @@ def __call__( # IP-adapter ip_query = hidden_states_query_proj - ip_attn_output = None - # for ip-adapter - # TODO: support for multiple adapters + ip_attn_output = torch.zeros_like(hidden_states) + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip ): @@ -2791,12 +2790,14 @@ def __call__( ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - ip_attn_output = F.scaled_dot_product_attention( + current_ip_hidden_states = F.scaled_dot_product_attention( ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) - ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_attn_output = scale * ip_attn_output - ip_attn_output = ip_attn_output.to(ip_query.dtype) + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) + ip_attn_output += scale * current_ip_hidden_states return hidden_states, encoder_hidden_states, ip_attn_output else: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 390b752abe15..04a0b273f1fa 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2583,6 +2583,11 @@ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[ super().__init__() self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) + @property + def num_ip_adapters(self) -> int: + """Number of IP-Adapters loaded.""" + return len(self.image_projection_layers) + def forward(self, image_embeds: List[torch.Tensor]): projected_image_embeds = [] diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 9f4788a4981a..46ea097243f2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -405,9 +405,9 @@ def prepare_ip_adapter_image_embeds( if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] - if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." ) for single_ip_adapter_image, image_proj_layer in zip( @@ -868,14 +868,23 @@ def __call__( else: guidance = None + # TODO: Clarify this section if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None ): - negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = ( + [np.zeros((width, height, 3), dtype=np.uint8) for _ in range(len(ip_adapter_image))] + if isinstance(ip_adapter_image, list) + else np.zeros((width, height, 3), dtype=np.uint8) + ) elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): - ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = ( + [np.zeros((width, height, 3), dtype=np.uint8) for _ in range(len(negative_ip_adapter_image))] + if isinstance(negative_ip_adapter_image, list) + else np.zeros((width, height, 3), dtype=np.uint8) + ) if self.joint_attention_kwargs is None: self._joint_attention_kwargs = {} diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 9a9afa198b4c..0e2cbb32d3c1 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -17,7 +17,7 @@ import re import warnings from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin import requests import torch @@ -1059,3 +1059,76 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict): break if has_transformers_component and not is_transformers_version(">", "4.47.1"): raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.") + + +def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: + """ + Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of + the correct type as well. + """ + if not isinstance(class_or_tuple, tuple): + class_or_tuple = (class_or_tuple,) + + # Unpack unions + unpacked_class_or_tuple = [] + for t in class_or_tuple: + if get_origin(t) is Union: + unpacked_class_or_tuple.extend(get_args(t)) + else: + unpacked_class_or_tuple.append(t) + class_or_tuple = tuple(unpacked_class_or_tuple) + + if Any in class_or_tuple: + return True + + obj_type = type(obj) + # Classes with obj's type + class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} + + # Singular types (e.g. int, ControlNet, ...) + # Untyped collections (e.g. List, but not List[int]) + elem_class_or_tuple = {get_args(t) for t in class_or_tuple} + if () in elem_class_or_tuple: + return True + # Typed lists or sets + elif obj_type in (list, set): + return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) + # Typed tuples + elif obj_type is tuple: + return any( + # Tuples with any length and single type (e.g. Tuple[int, ...]) + (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj)) + or + # Tuples with fixed length and any types (e.g. Tuple[int, str]) + (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t))) + for t in elem_class_or_tuple + ) + # Typed dicts + elif obj_type is dict: + return any( + all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items()) + for kt, vt in elem_class_or_tuple + ) + + else: + return False + + +def _get_detailed_type(obj: Any) -> Type: + """ + Gets a detailed type for an object, including nested types for collections. + """ + obj_type = type(obj) + + if obj_type in (list, set): + obj_origin_type = List if obj_type is list else Set + elems_type = Union[tuple({_get_detailed_type(x) for x in obj})] + return obj_origin_type[elems_type] + elif obj_type is tuple: + return Tuple[tuple(_get_detailed_type(x) for x in obj)] + elif obj_type is dict: + keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})] + values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})] + return Dict[keys_type, values_type] + else: + return obj_type From 79d1617e55ad395aabb9960961717ebb90a46ea6 Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Sat, 22 Feb 2025 18:47:33 +0000 Subject: [PATCH 2/9] Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky --- src/diffusers/pipelines/flux/pipeline_flux.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 46ea097243f2..4a5ea5058461 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -872,11 +872,9 @@ def __call__( if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None ): - negative_ip_adapter_image = ( - [np.zeros((width, height, 3), dtype=np.uint8) for _ in range(len(ip_adapter_image))] - if isinstance(ip_adapter_image, list) - else np.zeros((width, height, 3), dtype=np.uint8) - ) + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + if isinstance(ip_adapter_image, list): + negative_ip_adapter_image = [negative_ip_adapter_image] * len(ip_adapter_image) elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): From 339c0b4e3f8c09b5b4252fd3e47e305bd3ed0e90 Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Sat, 22 Feb 2025 18:47:39 +0000 Subject: [PATCH 3/9] Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky --- src/diffusers/pipelines/flux/pipeline_flux.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 4a5ea5058461..7b85a0e0473e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -878,11 +878,9 @@ def __call__( elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): - ip_adapter_image = ( - [np.zeros((width, height, 3), dtype=np.uint8) for _ in range(len(negative_ip_adapter_image))] - if isinstance(negative_ip_adapter_image, list) - else np.zeros((width, height, 3), dtype=np.uint8) - ) + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + if isinstance(negative_ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] * len(negative_ip_adapter_image) if self.joint_attention_kwargs is None: self._joint_attention_kwargs = {} From 9e9b0f858e6a0baf7d9d5e16f6a04941b676912f Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Sun, 23 Feb 2025 11:35:06 +0000 Subject: [PATCH 4/9] Changes for ipa image embeds --- .../pipeline_flux_semantic_guidance.py | 4 +-- src/diffusers/pipelines/flux/pipeline_flux.py | 29 +++++++++++-------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py index 972d8c73f7d6..919e0ad46bd1 100644 --- a/examples/community/pipeline_flux_semantic_guidance.py +++ b/examples/community/pipeline_flux_semantic_guidance.py @@ -537,9 +537,9 @@ def prepare_ip_adapter_image_embeds( if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] - if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." ) for single_ip_adapter_image, image_proj_layer in zip( diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 7b85a0e0473e..2771ac09119a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -410,18 +410,23 @@ def prepare_ip_adapter_image_embeds( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." ) - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers - ): + for single_ip_adapter_image in ip_adapter_image: single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) - image_embeds.append(single_image_embeds[None, :]) else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + for single_image_embeds in ip_adapter_image_embeds: image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): + for single_image_embeds in image_embeds: single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) @@ -868,19 +873,19 @@ def __call__( else: guidance = None - # TODO: Clarify this section if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None ): - negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) - if isinstance(ip_adapter_image, list): - negative_ip_adapter_image = [negative_ip_adapter_image] * len(ip_adapter_image) + zeros_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :] + negative_ip_adapter_image_embeds = [negative_ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): - ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) - if isinstance(negative_ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] * len(negative_ip_adapter_image) + zeros_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :] + ip_adapter_image_embeds = [ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters if self.joint_attention_kwargs is None: self._joint_attention_kwargs = {} From cd6f48a701fcb1d55e8a21c1b9bd58d6ad87e9c8 Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Mon, 24 Feb 2025 12:06:31 +0000 Subject: [PATCH 5/9] Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky --- src/diffusers/pipelines/flux/pipeline_flux.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 2771ac09119a..fcf1e2713191 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -883,9 +883,8 @@ def __call__( elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): - zeros_image = np.zeros((width, height, 3), dtype=np.uint8) - ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :] - ip_adapter_image_embeds = [ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters if self.joint_attention_kwargs is None: self._joint_attention_kwargs = {} From 20840d1b843627b26bad6c686c0b508c5b03a2b1 Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Mon, 24 Feb 2025 12:06:40 +0000 Subject: [PATCH 6/9] Update src/diffusers/pipelines/flux/pipeline_flux.py Co-authored-by: hlky --- src/diffusers/pipelines/flux/pipeline_flux.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index fcf1e2713191..da0be31fe09e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -876,9 +876,8 @@ def __call__( if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None ): - zeros_image = np.zeros((width, height, 3), dtype=np.uint8) - negative_ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :] - negative_ip_adapter_image_embeds = [negative_ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None From 1956da676ee841b2ad696b3393994cbdc548fa8a Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Mon, 24 Feb 2025 12:13:09 +0000 Subject: [PATCH 7/9] make style && make quality --- src/diffusers/loaders/ip_adapter.py | 3 +-- src/diffusers/pipelines/flux/pipeline_flux.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 224f6374fd01..0da40643c0c4 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -21,6 +21,7 @@ from safetensors import safe_open from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict +from ..pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -578,8 +579,6 @@ def LinearStrengthModel(start, finish, size): ``` """ - from ..pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type - scale_type = Union[int, float] num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters num_layers = self.transformer.config.num_layers diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index da0be31fe09e..e49371c0d5d2 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -878,7 +878,7 @@ def __call__( ): negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters - + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None ): From f7f92233a64883bb809999984f2c4cad7d576ae0 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Mon, 24 Feb 2025 17:23:10 +0000 Subject: [PATCH 8/9] Updated ip_adapter test --- tests/pipelines/test_pipelines_common.py | 41 ++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 33a7fd9f2b49..a98de5c9eaf9 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -527,7 +527,9 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N The following scenarios are tested: - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter. + - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter. - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. + - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter. """ # Raising the tolerance for this test when it's run on a CPU because we # compare against static slices and that can be shaky (with a VVVV low probability). @@ -545,6 +547,7 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N else: output_without_adapter = expected_pipe_slice + # 1. Single IP-Adapter test cases adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer) pipe.transformer._load_ip_adapter_weights(adapter_state_dict) @@ -578,6 +581,44 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" ) + # 2. Multi IP-Adapter test cases + adapter_state_dict_1 = create_flux_ip_adapter_state_dict(pipe.transformer) + adapter_state_dict_2 = create_flux_ip_adapter_state_dict(pipe.transformer) + pipe.transformer._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2]) + + # forward pass with multi ip adapter, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2 + inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2 + pipe.set_ip_adapter_scale([0.0, 0.0]) + output_without_multi_adapter_scale = pipe(**inputs)[0] + if expected_pipe_slice is not None: + output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with multi ip adapter, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2 + inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2 + pipe.set_ip_adapter_scale([42.0, 42.0]) + output_with_multi_adapter_scale = pipe(**inputs)[0] + if expected_pipe_slice is not None: + output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_multi_adapter_scale = np.abs( + output_without_multi_adapter_scale - output_without_adapter + ).max() + max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max() + self.assertLess( + max_diff_without_multi_adapter_scale, + expected_max_diff, + "Output without multi-ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_multi_adapter_scale, + 1e-2, + "Output with multi-ip-adapter scale must be different from normal inference", + ) + class PipelineLatentTesterMixin: """ From 905e8d789442b0c75f1ba878e6205aecb255b251 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Mon, 24 Feb 2025 18:20:47 +0000 Subject: [PATCH 9/9] Created typing_utils.py --- docs/source/en/_toctree.yml | 8 +- src/diffusers/loaders/ip_adapter.py | 3 +- .../pipelines/pipeline_loading_utils.py | 75 +-------------- src/diffusers/pipelines/pipeline_utils.py | 4 +- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/typing_utils.py | 91 +++++++++++++++++++ 6 files changed, 101 insertions(+), 81 deletions(-) create mode 100644 src/diffusers/utils/typing_utils.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a44a95911116..9f76be91339a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -543,6 +543,10 @@ title: Overview - local: api/schedulers/cm_stochastic_iterative title: CMStochasticIterativeScheduler + - local: api/schedulers/ddim_cogvideox + title: CogVideoXDDIMScheduler + - local: api/schedulers/multistep_dpm_solver_cogvideox + title: CogVideoXDPMScheduler - local: api/schedulers/consistency_decoder title: ConsistencyDecoderScheduler - local: api/schedulers/cosine_dpm @@ -551,8 +555,6 @@ title: DDIMInverseScheduler - local: api/schedulers/ddim title: DDIMScheduler - - local: api/schedulers/ddim_cogvideox - title: CogVideoXDDIMScheduler - local: api/schedulers/ddpm title: DDPMScheduler - local: api/schedulers/deis @@ -565,8 +567,6 @@ title: DPMSolverSDEScheduler - local: api/schedulers/singlestep_dpm_solver title: DPMSolverSinglestepScheduler - - local: api/schedulers/multistep_dpm_solver_cogvideox - title: CogVideoXDPMScheduler - local: api/schedulers/edm_multistep_dpm_solver title: EDMDPMSolverMultistepScheduler - local: api/schedulers/edm_euler diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 0da40643c0c4..33144090cbc6 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -21,10 +21,11 @@ from safetensors import safe_open from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict -from ..pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type from ..utils import ( USE_PEFT_BACKEND, + _get_detailed_type, _get_model_file, + _is_valid_type, is_accelerate_available, is_torch_version, is_transformers_available, diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 0e2cbb32d3c1..9a9afa198b4c 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -17,7 +17,7 @@ import re import warnings from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Union import requests import torch @@ -1059,76 +1059,3 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict): break if has_transformers_component and not is_transformers_version(">", "4.47.1"): raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.") - - -def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: - """ - Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of - the correct type as well. - """ - if not isinstance(class_or_tuple, tuple): - class_or_tuple = (class_or_tuple,) - - # Unpack unions - unpacked_class_or_tuple = [] - for t in class_or_tuple: - if get_origin(t) is Union: - unpacked_class_or_tuple.extend(get_args(t)) - else: - unpacked_class_or_tuple.append(t) - class_or_tuple = tuple(unpacked_class_or_tuple) - - if Any in class_or_tuple: - return True - - obj_type = type(obj) - # Classes with obj's type - class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} - - # Singular types (e.g. int, ControlNet, ...) - # Untyped collections (e.g. List, but not List[int]) - elem_class_or_tuple = {get_args(t) for t in class_or_tuple} - if () in elem_class_or_tuple: - return True - # Typed lists or sets - elif obj_type in (list, set): - return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) - # Typed tuples - elif obj_type is tuple: - return any( - # Tuples with any length and single type (e.g. Tuple[int, ...]) - (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj)) - or - # Tuples with fixed length and any types (e.g. Tuple[int, str]) - (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t))) - for t in elem_class_or_tuple - ) - # Typed dicts - elif obj_type is dict: - return any( - all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items()) - for kt, vt in elem_class_or_tuple - ) - - else: - return False - - -def _get_detailed_type(obj: Any) -> Type: - """ - Gets a detailed type for an object, including nested types for collections. - """ - obj_type = type(obj) - - if obj_type in (list, set): - obj_origin_type = List if obj_type is list else Set - elems_type = Union[tuple({_get_detailed_type(x) for x in obj})] - return obj_origin_type[elems_type] - elif obj_type is tuple: - return Tuple[tuple(_get_detailed_type(x) for x in obj)] - elif obj_type is dict: - keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})] - values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})] - return Dict[keys_type, values_type] - else: - return obj_type diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index e112947c8d5a..1b306b1805d8 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -54,6 +54,8 @@ DEPRECATED_REVISION_ARGS, BaseOutput, PushToHubMixin, + _get_detailed_type, + _is_valid_type, is_accelerate_available, is_accelerate_version, is_torch_npu_available, @@ -78,12 +80,10 @@ _fetch_class_library_tuple, _get_custom_components_and_folders, _get_custom_pipeline_class, - _get_detailed_type, _get_final_device_map, _get_ignore_patterns, _get_pipeline_class, _identify_model_variants, - _is_valid_type, _maybe_raise_error_for_incorrect_transformers, _maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls, diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d82aded4c435..08b1713d0e31 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -123,6 +123,7 @@ convert_state_dict_to_peft, convert_unet_state_dict_to_peft, ) +from .typing_utils import _get_detailed_type, _is_valid_type logger = get_logger(__name__) diff --git a/src/diffusers/utils/typing_utils.py b/src/diffusers/utils/typing_utils.py new file mode 100644 index 000000000000..2b5b1a4f5ab5 --- /dev/null +++ b/src/diffusers/utils/typing_utils.py @@ -0,0 +1,91 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Typing utilities: Utilities related to type checking and validation +""" + +from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args, get_origin + + +def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: + """ + Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of + the correct type as well. + """ + if not isinstance(class_or_tuple, tuple): + class_or_tuple = (class_or_tuple,) + + # Unpack unions + unpacked_class_or_tuple = [] + for t in class_or_tuple: + if get_origin(t) is Union: + unpacked_class_or_tuple.extend(get_args(t)) + else: + unpacked_class_or_tuple.append(t) + class_or_tuple = tuple(unpacked_class_or_tuple) + + if Any in class_or_tuple: + return True + + obj_type = type(obj) + # Classes with obj's type + class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} + + # Singular types (e.g. int, ControlNet, ...) + # Untyped collections (e.g. List, but not List[int]) + elem_class_or_tuple = {get_args(t) for t in class_or_tuple} + if () in elem_class_or_tuple: + return True + # Typed lists or sets + elif obj_type in (list, set): + return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) + # Typed tuples + elif obj_type is tuple: + return any( + # Tuples with any length and single type (e.g. Tuple[int, ...]) + (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj)) + or + # Tuples with fixed length and any types (e.g. Tuple[int, str]) + (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t))) + for t in elem_class_or_tuple + ) + # Typed dicts + elif obj_type is dict: + return any( + all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items()) + for kt, vt in elem_class_or_tuple + ) + + else: + return False + + +def _get_detailed_type(obj: Any) -> Type: + """ + Gets a detailed type for an object, including nested types for collections. + """ + obj_type = type(obj) + + if obj_type in (list, set): + obj_origin_type = List if obj_type is list else Set + elems_type = Union[tuple({_get_detailed_type(x) for x in obj})] + return obj_origin_type[elems_type] + elif obj_type is tuple: + return Tuple[tuple(_get_detailed_type(x) for x in obj)] + elif obj_type is dict: + keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})] + values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})] + return Dict[keys_type, values_type] + else: + return obj_type