Add Step1X-Edit Pipeline #12249
Conversation
Thanks for getting this started! Looks like a very cool model. I think this PR is already a very good start.
@linoytsaban / @asomoza in case you have some time to check it out.
```python
    processor._attention_backend = "_native_xla"
    return processor


class Step1XEditAttnProcessor2_0_NPU:
```
We can remove this processor for now.
```python
def apply_gate(x, gate=None, tanh=False):
    """AI is creating summary for apply_gate
```
What is this description?
```python
    return _get_projections(attn, hidden_states, encoder_hidden_states)


def get_activation_layer(act_type):
```
If the activations don't vary across different blocks, can we remove this function and just use the activation functions in-place?
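If that's the direction taken, the change is mechanical; a minimal sketch of what "in-place" could look like, assuming (purely for illustration) that every block uses the same activation:

```python
import torch
from torch import nn

# Instead of indirecting through a helper such as
#   act_layer = get_activation_layer("silu")
# the activation can be constructed where it is used, since it never
# varies across blocks (SiLU is an assumption here):
mlp = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 16))

out = mlp(torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 16])
```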
```python
    return x * gate.unsqueeze(1)


def get_norm_layer(norm_layer):
```
Same as above. It seems like the norm layers aren't changing. So, let's directly use nn.LayerNorm.
```python
self.to_v_ip = nn.ModuleList(
    [
        nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
        for _ in range(len(num_tokens))
    ]
)
```
Do we already support IP adapters for this model? If so, could you include an example? If not, let's remove this.
```python
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 1024,
```
If it's not used, let's remove.
```python
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
    img_info = image.size
    width, height = img_info
    r = width / height
```
Suggested change:

```diff
-    r = width / height
+    aspect_ratio = width / height
```
```python
    The prompt or prompts not to guide the image generation. If not defined, one has to pass
    `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
    not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 6.0):
```
I see we have both `guidance_scale` and `true_cfg_scale`. Is this meant to support future guidance-distilled variants? This model doesn't seem to be a guidance-distilled one.
```python
guidance_scale (`float`, *optional*, defaults to 6.0):
    Guidance scale as defined in [Classifier-Free Diffusion
    Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
    of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
    `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
    the text `prompt`, usually at the expense of lower image quality.
```
In the presence of the `true_cfg_scale` argument, we need to change this definition a bit:
```python
guidance_scale (`float`, *optional*, defaults to 3.5):
```
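For context: in pipelines that expose both knobs, `true_cfg_scale` usually drives a classic two-pass classifier-free guidance mix over positive and negative predictions, while `guidance_scale` is the scalar fed to the transformer for guidance-distilled checkpoints. A minimal sketch of the two-pass mix (plain numbers, function name assumed):

```python
def combine_true_cfg(noise_pred, neg_noise_pred, true_cfg_scale):
    # Classic CFG: start from the negative-prompt prediction and move
    # toward the positive-prompt prediction by the scale factor.
    return neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

# At scale 1.0 the negative branch cancels out entirely:
print(combine_true_cfg(0.75, 0.25, 1.0))  # 0.75
print(combine_true_cfg(0.75, 0.25, 2.0))  # 1.25
```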
```python
size_level (`int` defaults to 512): The maximum size level of the generated image in pixels. The height and width
    will be adjusted to fit this area while maintaining the aspect ratio.
```
Can't we derive this from the requested `height` and `width` parameters? Our pipelines never contain arguments like `size_level`.
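A sketch of how the same behavior could be derived from `height`/`width` instead: scale the requested size so its area is roughly `size_level**2`, keep the aspect ratio, and snap to the model's spatial granularity (the multiple of 16 is an assumption here, as is the helper name):

```python
import math

def derive_target_size(width, height, size_level=512, multiple=16):
    """Scale (width, height) so the area is about size_level**2 pixels,
    preserving the aspect ratio and snapping to `multiple` (assumed to
    be the VAE/patch granularity)."""
    scale = math.sqrt((size_level * size_level) / (width * height))
    new_w = max(multiple, round(width * scale / multiple) * multiple)
    new_h = max(multiple, round(height * scale / multiple) * multiple)
    return new_w, new_h

print(derive_target_size(1024, 768))  # (592, 448)
print(derive_target_size(512, 512))   # (512, 512)
```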
Thanks for the PR!
I left some comments.
```python
import numpy as np
import torch
import math
from qwen_vl_utils import process_vision_info
```
Can you try to not have this dependency?
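If the pipeline only needs `process_vision_info` to collect the image entries from Qwen-VL style chat messages, a small local helper could avoid the dependency. This is a hedged sketch: it assumes the standard Qwen-VL message layout and does not reimplement `qwen_vl_utils`' URL fetching or video handling.

```python
def extract_images(messages):
    """Collect image entries from Qwen-VL style chat messages.

    Assumed layout (an assumption, not verified against this PR):
      [{"role": ..., "content": [{"type": "image", "image": <PIL/path>}, ...]}, ...]
    """
    images = []
    for message in messages:
        content = message.get("content", [])
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") == "image":
                    images.append(item["image"])
    return images

msgs = [{"role": "user", "content": [{"type": "image", "image": "cat.png"},
                                     {"type": "text", "text": "make it a dog"}]}]
print(extract_images(msgs))  # ['cat.png']
```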
```python
self.gradient_checkpointing = False

@staticmethod
def timestep_embedding(
```
Let's not make it a method of the transformer class.
Actually, is it the same as https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py#L1302 ?
```python
if txt_ids.ndim == 3:
    logger.warning(
        "Passing `txt_ids` 3d torch.Tensor is deprecated."
        "Please remove the batch dimension and pass it as a 2d torch Tensor"
    )
    txt_ids = txt_ids[0]
if img_ids.ndim == 3:
    logger.warning(
        "Passing `img_ids` 3d torch.Tensor is deprecated."
        "Please remove the batch dimension and pass it as a 2d torch Tensor"
    )
    img_ids = img_ids[0]
```
Suggested change (delete):

```diff
-if txt_ids.ndim == 3:
-    logger.warning(
-        "Passing `txt_ids` 3d torch.Tensor is deprecated."
-        "Please remove the batch dimension and pass it as a 2d torch Tensor"
-    )
-    txt_ids = txt_ids[0]
-if img_ids.ndim == 3:
-    logger.warning(
-        "Passing `img_ids` 3d torch.Tensor is deprecated."
-        "Please remove the batch dimension and pass it as a 2d torch Tensor"
-    )
-    img_ids = img_ids[0]
```
We don't need a deprecation path for a new model class.
```python
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
    ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
    ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
    joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
```
Suggested change (delete):

```diff
-if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
-    ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
-    ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
-    joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
```
There is no IP adapter yet, no?
```python
# controlnet residual
if controlnet_block_samples is not None:
    interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
    interval_control = int(np.ceil(interval_control))
    # For Xlabs ControlNet.
    if controlnet_blocks_repeat:
        hidden_states = (
            hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
        )
    else:
        hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
```
Suggested change (delete):

```diff
-# controlnet residual
-if controlnet_block_samples is not None:
-    interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
-    interval_control = int(np.ceil(interval_control))
-    # For Xlabs ControlNet.
-    if controlnet_blocks_repeat:
-        hidden_states = (
-            hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
-        )
-    else:
-        hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
```
Let's add ControlNet support when we have such models. :)
```python
    x: torch.Tensor,
    t: torch.LongTensor,
    mask: Optional[torch.LongTensor] = None,
    y: torch.LongTensor=None,
```
Suggested change (delete):

```diff
-    y: torch.LongTensor=None,
```
```python
if self.need_CA:
    self.input_embedder_CA = nn.Linear(
        in_channels, hidden_size, bias=True, **factory_kwargs
    )
```
Suggested change (delete):

```diff
-if self.need_CA:
-    self.input_embedder_CA = nn.Linear(
-        in_channels, hidden_size, bias=True, **factory_kwargs
-    )
```
If this layer is not used in this checkpoint, let's just not have it.
```python
if self.need_CA:
    y = self.input_embedder_CA(y)
    x = self.individual_token_refiner(x, c, mask, y)
else:
```
Suggested change (delete):

```diff
-if self.need_CA:
-    y = self.input_embedder_CA(y)
-    x = self.individual_token_refiner(x, c, mask, y)
-else:
```
```python
global_out = self.global_proj_out(x_mean)

encoder_hidden_states = self.S(x,t,mask)
```
It seems like the SingleTokenRefiner should be its own layer, not part of the connector: the inputs pass through without processing here. So:

encoder_hidden_states, mask -> global_proj -> global_out
encoder_hidden_states, timesteps, mask -> single token refiner -> encoder_hidden_states
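A minimal sketch of the first branch of that split, with every module name and dimension assumed purely for illustration (the refiner would then consume `encoder_hidden_states, timesteps, mask` as a separate, independent module):

```python
import torch
from torch import nn

class GlobalProj(nn.Module):
    """Pools the text hidden states (masked mean) and projects to a
    global vector, standalone rather than buried in a connector."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, encoder_hidden_states, mask):
        mask = mask.unsqueeze(-1).to(encoder_hidden_states.dtype)
        # Masked mean over the sequence dimension.
        x_mean = (encoder_hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1.0)
        return self.proj(x_mean)

proj = GlobalProj(32, 64)
hs = torch.randn(2, 7, 32)
mask = torch.ones(2, 7, dtype=torch.long)
global_out = proj(hs, mask)
print(global_out.shape)  # torch.Size([2, 64])
```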
@sayakpaul @yiyixuxu Thank you very much for your patient review. We've made some changes according to your feedback. We sincerely appreciate your efforts once again!
What does this PR do?
This PR adds support for the Step1X-Edit model for image editing tasks, integrating it into the Diffusers library. For further details about the Step1X-Edit model, please refer to the GitHub Repo and the Technical Report.
Example Code
Result
Init Image
Edited Image
Who can review?
cc @a-r-r-o-w @sayakpaul