|
18 | 18 | import torch |
19 | 19 | import torch.nn.functional as F |
20 | 20 | from torch import nn |
| 21 | +from einops import rearrange |
21 | 22 |
|
22 | 23 | from ..image_processor import IPAdapterMaskProcessor |
23 | 24 | from ..utils import deprecate, logging |
@@ -4800,6 +4801,144 @@ def __call__( |
4800 | 4801 |         hidden_states = hidden_states / attn.rescale_output_factor |
4801 | 4802 |
|
4802 | 4803 |         return hidden_states |
| 4804 | + |
| 4805 | + |
| 4806 | +class IPAdapterJointAttnProcessor2_0(torch.nn.Module): |
| 4807 | +    """Attention processor for IP-Adapter, typically used for processing SD3-like joint self-attention projections.""" |
| 4808 | + |
| 4809 | +    def __init__( |
| 4810 | +        self, |
| 4811 | +        hidden_size: int, |
| 4812 | +        ip_hidden_states_dim: int, |
| 4813 | +        head_dim: int, |
| 4814 | +        timesteps_emb_dim: int = 1280, |
| 4815 | +        scale: float = 0.5, |
| 4816 | +    ): |
| 4817 | +        super().__init__() |
| 4818 | + |
| 4819 | +        # To prevent circular import |
| 4820 | +        from .normalization import AdaLayerNorm, RMSNorm |
| 4821 | + |
| 4822 | +        self.norm_ip = AdaLayerNorm( |
| 4823 | +            timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1) |
| 4824 | +        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) |
| 4825 | +        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) |
| 4826 | +        self.norm_q = RMSNorm(head_dim, 1e-6) |
| 4827 | +        self.norm_k = RMSNorm(head_dim, 1e-6) |
| 4828 | +        self.norm_ip_k = RMSNorm(head_dim, 1e-6) |
| 4829 | +        self.scale = scale |
| 4830 | + |
| 4831 | +    def __call__( |
| 4832 | +        self, |
| 4833 | +        attn: Attention, |
| 4834 | +        hidden_states: torch.FloatTensor, |
| 4835 | +        encoder_hidden_states: torch.FloatTensor = None, |
| 4836 | +        attention_mask: Optional[torch.FloatTensor] = None, |
| 4837 | +        ip_hidden_states: torch.FloatTensor = None, |
| 4838 | +        temb: torch.FloatTensor = None, |
| 4839 | +    ) -> torch.FloatTensor: |
| 4840 | +        residual = hidden_states |
| 4841 | + |
| 4842 | +        batch_size = hidden_states.shape[0] |
| 4843 | + |
| 4844 | +        # `sample` projections. |
| 4845 | +        query = attn.to_q(hidden_states) |
| 4846 | +        key = attn.to_k(hidden_states) |
| 4847 | +        value = attn.to_v(hidden_states) |
| 4848 | +        img_query = query |
| 4849 | +        img_key = key |
| 4850 | +        img_value = value |
| 4851 | + |
| 4852 | +        inner_dim = key.shape[-1] |
| 4853 | +        head_dim = inner_dim // attn.heads |
| 4854 | + |
| 4855 | +        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 4856 | +        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 4857 | +        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 4858 | + |
| 4859 | +        if attn.norm_q is not None: |
| 4860 | +            query = attn.norm_q(query) |
| 4861 | +        if attn.norm_k is not None: |
| 4862 | +            key = attn.norm_k(key) |
| 4863 | + |
| 4864 | +        # `context` projections. |
| 4865 | +        if encoder_hidden_states is not None: |
| 4866 | +            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) |
| 4867 | +            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) |
| 4868 | +            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) |
| 4869 | + |
| 4870 | +            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( |
| 4871 | +                batch_size, -1, attn.heads, head_dim |
| 4872 | +            ).transpose(1, 2) |
| 4873 | +            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( |
| 4874 | +                batch_size, -1, attn.heads, head_dim |
| 4875 | +            ).transpose(1, 2) |
| 4876 | +            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( |
| 4877 | +                batch_size, -1, attn.heads, head_dim |
| 4878 | +            ).transpose(1, 2) |
| 4879 | + |
| 4880 | +            if attn.norm_added_q is not None: |
| 4881 | +                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) |
| 4882 | +            if attn.norm_added_k is not None: |
| 4883 | +                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) |
| 4884 | + |
| 4885 | +            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) |
| 4886 | +            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) |
| 4887 | +            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) |
| 4888 | + |
| 4889 | +        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) |
| 4890 | +        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
| 4891 | +        hidden_states = hidden_states.to(query.dtype) |
| 4892 | + |
| 4893 | +        if encoder_hidden_states is not None: |
| 4894 | +            # Split the attention outputs. |
| 4895 | +            hidden_states, encoder_hidden_states = ( |
| 4896 | +                hidden_states[:, : residual.shape[1]], |
| 4897 | +                hidden_states[:, residual.shape[1] :], |
| 4898 | +            ) |
| 4899 | +            if not attn.context_pre_only: |
| 4900 | +                encoder_hidden_states = attn.to_add_out(encoder_hidden_states) |
| 4901 | + |
| 4902 | +        # IP-Adapter: add image-prompt conditioning through a parallel attention branch. |
| 4903 | +        if self.scale != 0 and ip_hidden_states is not None: |
| 4904 | +            # Normalize image features, conditioned on the timestep embedding. |
| 4905 | +            norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb) |
| 4906 | + |
| 4907 | +            # Project image features to keys and values. |
| 4908 | +            ip_key = self.to_k_ip(norm_ip_hidden_states) |
| 4909 | +            ip_value = self.to_v_ip(norm_ip_hidden_states) |
| 4910 | + |
| 4911 | +            # Reshape to (batch, heads, seq_len, head_dim). |
| 4912 | +            img_query = rearrange(img_query, 'b l (h d) -> b h l d', h=attn.heads) |
| 4913 | +            img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads) |
| 4914 | +            img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads) |
| 4915 | +            ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads) |
| 4916 | +            ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads) |
| 4917 | + |
| 4918 | +            # RMS-normalize queries and keys. |
| 4919 | +            img_query = self.norm_q(img_query) |
| 4920 | +            img_key = self.norm_k(img_key) |
| 4921 | +            ip_key = self.norm_ip_k(ip_key) |
| 4922 | + |
| 4923 | +            # Concatenate image and IP-Adapter keys/values along the sequence dimension. |
| 4924 | +            img_key = torch.cat([img_key, ip_key], dim=2) |
| 4925 | +            img_value = torch.cat([img_value, ip_value], dim=2) |
| 4926 | + |
| 4927 | +            ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False) |
| 4928 | +            ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)') |
| 4929 | +            ip_hidden_states = ip_hidden_states.to(img_query.dtype) |
| 4930 | + |
| 4931 | +            hidden_states = hidden_states + ip_hidden_states * self.scale |
| 4932 | + |
| 4933 | +        # linear proj |
| 4934 | +        hidden_states = attn.to_out[0](hidden_states) |
| 4935 | +        # dropout |
| 4936 | +        hidden_states = attn.to_out[1](hidden_states) |
| 4937 | + |
| 4938 | +        if encoder_hidden_states is not None: |
| 4939 | +            return hidden_states, encoder_hidden_states |
| 4940 | +        else: |
| 4941 | +            return hidden_states |
4803 | 4942 |
|
4804 | 4943 |
|
4805 | 4944 | class PAGIdentitySelfAttnProcessor2_0: |
@@ -5089,6 +5228,7 @@ def __init__(self): |
5089 | 5228 |     IPAdapterAttnProcessor, |
5090 | 5229 |     IPAdapterAttnProcessor2_0, |
5091 | 5230 |     IPAdapterXFormersAttnProcessor, |
| 5231 | +    IPAdapterJointAttnProcessor2_0, |
5092 | 5232 |     PAGIdentitySelfAttnProcessor2_0, |
5093 | 5233 |     PAGCFGIdentitySelfAttnProcessor2_0, |
5094 | 5234 |     LoRAAttnProcessor, |
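
Usage sketch (a minimal, hedged example, not the pipeline's actual IP-Adapter loading path): the snippet below shows one way the new processor could be attached to an SD3-style transformer. SD3Transformer2DModel, attn_processors, and set_attn_processor are existing diffusers APIs; the checkpoint id, the feature dimensions, and the way ip_hidden_states/temb reach the processor at denoising time are assumptions for illustration only.

import torch

from diffusers import SD3Transformer2DModel
from diffusers.models.attention_processor import IPAdapterJointAttnProcessor2_0

# Assumed SD3-medium geometry: 24 heads x 64 channels = 1536 hidden size.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch.float16
)
hidden_size = 1536
head_dim = 64
ip_hidden_states_dim = 1152  # assumed width of the projected image-prompt tokens

# One processor per attention module; each instance owns its own to_k_ip/to_v_ip weights.
ip_procs = {
    name: IPAdapterJointAttnProcessor2_0(
        hidden_size=hidden_size,
        ip_hidden_states_dim=ip_hidden_states_dim,
        head_dim=head_dim,
        timesteps_emb_dim=1280,
        scale=0.5,
    )
    for name in transformer.attn_processors
}
transformer.set_attn_processor(ip_procs)

# During denoising, each block is then expected to forward the projected image tokens and the
# timestep embedding to its processor, roughly:
#   attn(hidden_states, encoder_hidden_states, ip_hidden_states=image_tokens, temb=timestep_emb)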
|