Skip to content

Conversation

@Peyton-Chen
Copy link

@Peyton-Chen Peyton-Chen commented Aug 28, 2025

What does this PR do?

This PR adds support for the Step1X-Edit model for image editing tasks, extending its integration within the Diffusers library. For further details regarding the Step1X-Edit model, please refer to the GitHub Repo and the Technical Report.

Example Code

import torch
from diffusers import Step1XEditPipeline
from diffusers.utils import load_image

pipe = Step1XEditPipeline.from_pretrained("stepfun-ai/Step1X-Edit-v1p1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = load_image(
    "https://github.com/stepfun-ai/Step1X-Edit/blob/main/examples/0000.jpg?raw=true"
).convert("RGB")
prompt = "Add pendant with a ruby around this girl's neck."

image = pipe(
    image=image,
    prompt=prompt, 
    num_inference_steps=28,
    size_level=1024,
    guidance_scale=6.0,
    generator=torch.Generator().manual_seed(1234),
).images[0]
image.save("output.png")

Result

Init Image

Init image

Edited Image

Mask

Who can review?

cc @a-r-r-o-w @sayakpaul

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for getting this started! Looks like a very cool model. I think this PR is already a very good start.

@linoytsaban / @asomoza in case you have some time to check it out.

processor._attention_backend = "_native_xla"
return processor

class Step1XEditAttnProcessor2_0_NPU:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this processor for now.



def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this description?

return _get_projections(attn, hidden_states, encoder_hidden_states)


def get_activation_layer(act_type):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the activations don't vary across different blocks, can we remove this function and just use the activation functions in-place?

return x * gate.unsqueeze(1)


def get_norm_layer(norm_layer):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. It seems like the norm layers aren't changing. So, let's directly use nn.LayerNorm.

Comment on lines 211 to 216
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we already support IP adapters for this model? If so, could you include an example? If not, let's remove this.

num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 1024,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's not used, let's remove.

if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
img_info = image.size
width, height = img_info
r = width / height
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
r = width / height
aspect_ratio = width / height

The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 6.0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see we have both guidance_scale and true_cfg_scale. Is this support future guidance-distilled models as the model doesn't seem to be a guidance-distilled model?

Comment on lines 726 to 731
guidance_scale (`float`, *optional*, defaults to 6.0):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In presence of the true_cfg_scale argument, we need to change this definition a bit:

guidance_scale (`float`, *optional*, defaults to 3.5):

Comment on lines 767 to 768
size_level (`int` defaults to 512): The maximum size level of the generated image in pixels. The height and width will be adjusted to fit this
area while maintaining the aspect ratio.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we derive this from the requested height and width parameters? Our pipelines don't ever contain arguments like size_level.

@sayakpaul sayakpaul requested a review from a-r-r-o-w August 30, 2025 07:12
Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the PR!
I left some comments

import numpy as np
import torch
import math
from qwen_vl_utils import process_vision_info
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you try to not have this dependency?

self.gradient_checkpointing = False

@staticmethod
def timestep_embedding(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not make it a method of the transformer class
actually is it same as? https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py#L1302

Comment on lines 1324 to 1335
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]

we don't need to deprecate for new model class

Comment on lines 1340 to 1343
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

there is no ip-adapter yet,no?

Comment on lines 1365 to 1375
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

let's add controlnet when we have them:)

x: torch.Tensor,
t: torch.LongTensor,
mask: Optional[torch.LongTensor] = None,
y: torch.LongTensor=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
y: torch.LongTensor=None,

Comment on lines 1030 to 1033
if self.need_CA:
self.input_embedder_CA = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.need_CA:
self.input_embedder_CA = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)

if this layer is not used in this checkpoint, let's just not have it

Comment on lines 1077 to 1080
if self.need_CA:
y = self.input_embedder_CA(y)
x = self.individual_token_refiner(x, c, mask, y)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.need_CA:
y = self.input_embedder_CA(y)
x = self.individual_token_refiner(x, c, mask, y)
else:


global_out = self.global_proj_out(x_mean)

encoder_hidden_states = self.S(x,t,mask)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the SingleTokenRefiner should be its own layer, not part of connector: the inputs are passing through without processing here

so
encoder_hidden_states, mask -> global_proj -> global_out
encoder_hidden_states, timesteps, mask -> single token refiner -> encoder_hidden_state

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your review. We have resolved all other comments! Regarding the design of this connector, this class corresponds to the structural design outlined in the technical report, so we have retained this design.
image

@Peyton-Chen
Copy link
Author

@sayakpaul @yiyixuxu Thank you very much for your patient review. We've made some changes according to your feedback. We sincerely appreciate your efforts once again!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants