Feature IP Adapter Xformers Attention Processor #9881
Changes from 4 commits
```diff
@@ -369,7 +369,20 @@ def set_use_memory_efficient_attention_xformers(
             )
             processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
         else:
-            processor = XFormersAttnProcessor(attention_op=attention_op)
+            processor = self.processor
+            if isinstance(self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+                processor = IPAdapterXFormersAttnProcessor(hidden_size=self.processor.hidden_size,
+                                                           cross_attention_dim=self.processor.cross_attention_dim,
+                                                           scale=self.processor.scale,
+                                                           num_tokens=self.processor.num_tokens,
+                                                           attention_op=attention_op)
+                processor.load_state_dict(self.processor.state_dict())
```
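For context, here is a minimal usage sketch of the behaviour this hunk enables; the pipeline checkpoint and adapter names below are illustrative examples, not taken from the PR:

```python
# Illustrative sketch only; model/adapter names are examples, not from this PR.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# With this change, enabling xformers swaps the loaded IP-Adapter processors to
# IPAdapterXFormersAttnProcessor instead of replacing them with the plain
# XFormersAttnProcessor (which would drop the IP-Adapter weights).
pipe.enable_xformers_memory_efficient_attention()
```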
```python
self.to_k_ip = nn.ModuleList(
```

(I had forgotten about that, sorry! lol)
what does this section of code do?
This forwards the initialization parameters that were already defined when the IP-Adapter model was loaded to the new xformers attention class. After that I load the already-loaded state_dict into the new object, as explained in the previous question. Then I just make sure the weights end up on the same device and with the same dtype as before, because after reloading the state_dict they are placed on "cpu" with dtype "float32".

I also changed the initial if statement on line 380 to check for existing modules, just to avoid unexpected errors:

```python
if hasattr(self.processor, "_modules") and len(self.processor._modules) > 0:
```
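A rough sketch of the flow described above (not the exact PR code; `attention_op` is the argument passed into set_use_memory_efficient_attention_xformers):

```python
# Sketch of the swap described above, assuming self.processor is an IP-Adapter processor.
old = self.processor
new = IPAdapterXFormersAttnProcessor(
    hidden_size=old.hidden_size,
    cross_attention_dim=old.cross_attention_dim,
    scale=old.scale,
    num_tokens=old.num_tokens,
    attention_op=attention_op,
)
# Copy the already-loaded IP-Adapter weights, then restore the original device/dtype,
# since the freshly constructed processor starts out on CPU in float32.
new.load_state_dict(old.state_dict())
if hasattr(old, "_modules") and len(old._modules) > 0:
    ref = next(old.parameters())
    new.to(device=ref.device, dtype=ref.dtype)
processor = new
```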
ok, I think we can simplify the code here a little bit: since we are inside an if statement, we already know the processor will be either IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0 or IPAdapterXFormersAttnProcessor -- in all three cases it will have a to_k_ip layer and a to_v_ip layer, so maybe we can just get the device info from self.to_k_ip[0].device
Oh yes, I had done it in an agnostic way because I wasn't sure those modules would always be present in every model that might arrive there. I'll change it to your solution.
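A minimal sketch of that simplification, assuming the IP-Adapter processor always exposes a to_k_ip ModuleList (true for all three IP-Adapter processor classes):

```python
# Sketch only: read device/dtype from the existing projection weights rather than
# probing self.processor._modules generically.
ref = self.processor.to_k_ip[0].weight
processor.load_state_dict(self.processor.state_dict())
processor.to(device=ref.device, dtype=ref.dtype)
```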
this check can be removed because this class uses xformers.ops.memory_efficient_attention instead of torch.nn.functional.scaled_dot_product_attention
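For illustration, the check being referred to is presumably the PyTorch 2.0 availability guard that the SDPA-based processors carry; an xformers-based processor has no use for it:

```python
# Hypothetical example of the guard that becomes unnecessary here: the SDPA-based
# processors refuse to construct without torch 2.0, but this class calls
# xformers.ops.memory_efficient_attention and never touches scaled_dot_product_attention.
import torch.nn.functional as F

if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError("This processor requires PyTorch 2.0 or newer.")
```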
```diff
-query = attn.head_to_batch_dim(query)
-key = attn.head_to_batch_dim(key)
-value = attn.head_to_batch_dim(value)
-hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+query = attn.head_to_batch_dim(query).contiguous()
+key = attn.head_to_batch_dim(key).contiguous()
+value = attn.head_to_batch_dim(value).contiguous()
+hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op)
```
just another observation: if we make the tensors contiguous here, we can avoid multiple calls to query.contiguous() later in the code (every time self._memory_efficient_attention_xformers is called, query is reused)
this way, we can directly call xformers.ops.memory_efficient_attention
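To make the observation concrete, here is a sketch of what the helper presumably does (its exact body isn't shown in this excerpt); making the tensors contiguous once at the call site lets each site call xformers directly and drop the wrapper:

```python
# Presumed shape of the PR's helper (not shown in this excerpt): it re-runs
# .contiguous() on every call, even though `query` is reused across the text
# branch and the IP-Adapter branches.
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()
    return xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask, op=self.attention_op
    )
```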
```diff
-ip_key = attn.head_to_batch_dim(ip_key)
-ip_value = attn.head_to_batch_dim(ip_value)
-_current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
+ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+_current_ip_hidden_states = xformers.ops.memory_efficient_attention(query, ip_key, ip_value, op=self.attention_op)
```
same as before
```diff
-ip_key = attn.head_to_batch_dim(ip_key)
-ip_value = attn.head_to_batch_dim(ip_value)
-current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
+ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+current_ip_hidden_states = xformers.ops.memory_efficient_attention(query, ip_key, ip_value, op=self.attention_op)
```
Thank you! Done!


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's add an is_ip_adapter flag similar to is_custom_diffusion etc.
Yes, that is perfectly possible, but it will have to be like below, so that modules that have already been switched to the xformers attention class are not replaced again with the XFormersAttnProcessor class in the final else branch during the method recursion.
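The code referenced "below" isn't included in this excerpt, but a hedged sketch of the kind of flag being discussed (names modelled on the existing is_custom_diffusion-style checks, not the actual PR code) could look like:

```python
# Hypothetical sketch only; the actual PR code is not shown in this excerpt.
is_ip_adapter = isinstance(
    self.processor,
    (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
)

if is_ip_adapter:
    processor = IPAdapterXFormersAttnProcessor(
        hidden_size=self.processor.hidden_size,
        cross_attention_dim=self.processor.cross_attention_dim,
        scale=self.processor.scale,
        num_tokens=self.processor.num_tokens,
        attention_op=attention_op,
    )
    processor.load_state_dict(self.processor.state_dict())
else:
    # Falling through to the plain xformers processor unconditionally would undo the
    # IP-Adapter swap when this method recurses into child modules.
    processor = XFormersAttnProcessor(attention_op=attention_op)
```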
yes, the code you show here is ok!
we just want to keep a consistent style, that's all :)