diff --git a/scripts/convert_step1x_edit_to_diffusers.py b/scripts/convert_step1x_edit_to_diffusers.py new file mode 100644 index 000000000000..888832cbbcc7 --- /dev/null +++ b/scripts/convert_step1x_edit_to_diffusers.py @@ -0,0 +1,597 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download +from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor + +from diffusers import AutoencoderKL, FluxTransformer2DModel, Step1XEditTransformer2DModel, FlowMatchEulerDiscreteScheduler +from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint +from diffusers.utils.import_utils import is_accelerate_available + + +""" +# Transformer + +python scripts/convert_step1x_edit_to_diffusers.py \ +--checkpoint_path "/mnt/lib/Step1X-Edit/step1x-edit-v1p1-official.safetensors" \ +--output_path "/mnt/lib/Step1X-Edit-diffusers" \ +--transformer + +""" + +""" +# VAE + +python scripts/convert_step1x_edit_to_diffusers.py \ +--checkpoint_path "/mnt/lib/FLUX.1-dev/ae.safetensors" \ +--output_path "/mnt/lib/Step1X-Edit-diffusers" \ +--dtype "fp32" \ +--vae + +""" + +""" +# LLM Encoder + +python scripts/convert_step1x_edit_to_diffusers.py \ +--original_state_dict_repo_id "/mnt/lib/Step1X-Edit/Qwen2.5-VL-7B-Instruct" \ +--output_path "/mnt/lib/Step1X-Edit-diffusers" \ +--text_encoder +""" + +""" +# Scheduler + +python scripts/convert_step1x_edit_to_diffusers.py \ +--original_state_dict_repo_id "/mnt/lib/FLUX.1-dev" \ +--output_path "/mnt/lib/Step1X-Edit-diffusers" \ +--scheduler +""" + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_state_dict_repo_id", default=None, type=str) +parser.add_argument("--filename", default="flux.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--in_channels", type=int, default=64) +parser.add_argument("--out_channels", type=int, default=64) +parser.add_argument("--vae", action="store_true") +parser.add_argument("--text_encoder", action="store_true") +parser.add_argument("--transformer", action="store_true") +parser.add_argument("--scheduler", action="store_true") +parser.add_argument("--output_path", type=str) +parser.add_argument("--dtype", type=str, default="bf16") + +args = parser.parse_args() +dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + + +def load_original_checkpoint(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_step1x_edit_transformer_checkpoint_to_diffusers( + original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 +): + converted_state_dict = {} + + ## time_embed <- time_in + converted_state_dict["time_embed.in_layer.weight"] = original_state_dict.pop( + "time_in.in_layer.weight" + ) + converted_state_dict["time_embed.in_layer.bias"] = original_state_dict.pop( + "time_in.in_layer.bias" + ) + converted_state_dict["time_embed.out_layer.weight"] = original_state_dict.pop( + "time_in.out_layer.weight" + ) + converted_state_dict["time_embed.out_layer.bias"] = original_state_dict.pop( + "time_in.out_layer.bias" + ) + + ## vec_embed <- vector_in + converted_state_dict["vec_embed.in_layer.weight"] = original_state_dict.pop( + "vector_in.in_layer.weight" + ) + converted_state_dict["vec_embed.in_layer.bias"] = original_state_dict.pop( + "vector_in.in_layer.bias" + ) + converted_state_dict["vec_embed.out_layer.weight"] = original_state_dict.pop( + "vector_in.out_layer.weight" + ) + converted_state_dict["vec_embed.out_layer.bias"] = original_state_dict.pop( + "vector_in.out_layer.bias" + ) + + # context_embedder + converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight") + converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias") + + # x_embedder + converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight") + converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias") + + # connector + remaining_key = list(original_state_dict.keys()) + for key in remaining_key: + if 'connector' in key: + converted_state_dict[key] = original_state_dict.pop(key) + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + # norms. + ## norm1 + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.bias" + ) + ## norm1_context + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.bias" + ) + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0 + ) + context_q, context_k, context_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.2.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.2.bias" + ) + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.bias" + ) + + # single transformer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.bias" + ) + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) + q_bias, k_bias, v_bias, mlp_bias = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + + # output projections. + converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.weight" + ) + converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.bias" + ) + + converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.bias") + ) + + print(original_state_dict.keys()) + + return converted_state_dict + + +def convert_step1x_edit_vae_checkpoint_to_diffusers(original_state_dict): + converted_state_dict = {} + + # encoder.conv_in + converted_state_dict["encoder.conv_in.weight"] = original_state_dict.pop( + "encoder.conv_in.weight" + ) + converted_state_dict["encoder.conv_in.bias"] = original_state_dict.pop( + "encoder.conv_in.bias" + ) + + # encoder.conv_out + converted_state_dict["encoder.conv_out.weight"] = original_state_dict.pop( + "encoder.conv_out.weight" + ) + converted_state_dict["encoder.conv_out.bias"] = original_state_dict.pop( + "encoder.conv_out.bias" + ) + + # encoder.norm_out + converted_state_dict["encoder.conv_norm_out.weight"] = original_state_dict.pop( + "encoder.norm_out.weight" + ) + converted_state_dict["encoder.conv_norm_out.bias"] = original_state_dict.pop( + "encoder.norm_out.bias" + ) + + # encoder.down + for i in range(4): + # conv & norm + for j in range(2): + for k in range(1, 3): + converted_state_dict[f"encoder.down_blocks.{i}.resnets.{j}.conv{k}.weight"] = original_state_dict.pop( + f"encoder.down.{i}.block.{j}.conv{k}.weight" + ) + converted_state_dict[f"encoder.down_blocks.{i}.resnets.{j}.conv{k}.bias"] = original_state_dict.pop( + f"encoder.down.{i}.block.{j}.conv{k}.bias" + ) + converted_state_dict[f"encoder.down_blocks.{i}.resnets.{j}.norm{k}.weight"] = original_state_dict.pop( + f"encoder.down.{i}.block.{j}.norm{k}.weight" + ) + converted_state_dict[f"encoder.down_blocks.{i}.resnets.{j}.norm{k}.bias"] = original_state_dict.pop( + f"encoder.down.{i}.block.{j}.norm{k}.bias" + ) + + # downsample + if i != 3 : + converted_state_dict[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = original_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + converted_state_dict[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = original_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + # shortcut + if i == 1 or i == 2: + converted_state_dict[f"encoder.down_blocks.{i}.resnets.0.conv_shortcut.weight"] = original_state_dict.pop( + f"encoder.down.{i}.block.0.nin_shortcut.weight" + ) + converted_state_dict[f"encoder.down_blocks.{i}.resnets.0.conv_shortcut.bias"] = original_state_dict.pop( + f"encoder.down.{i}.block.0.nin_shortcut.bias" + ) + + # encoder.mid + converted_state_dict["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop( + "encoder.mid.attn_1.q.weight" + ).squeeze() + converted_state_dict["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop( + "encoder.mid.attn_1.q.bias" + ) + converted_state_dict["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop( + "encoder.mid.attn_1.k.weight" + ).squeeze() + converted_state_dict["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop( + "encoder.mid.attn_1.k.bias" + ) + converted_state_dict["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop( + "encoder.mid.attn_1.v.weight" + ).squeeze() + converted_state_dict["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop( + "encoder.mid.attn_1.v.bias" + ) + converted_state_dict["encoder.mid_block.attentions.0.group_norm.weight"] = original_state_dict.pop( + "encoder.mid.attn_1.norm.weight" + ) + converted_state_dict["encoder.mid_block.attentions.0.group_norm.bias"] = original_state_dict.pop( + "encoder.mid.attn_1.norm.bias" + ) + converted_state_dict["encoder.mid_block.attentions.0.to_out.0.weight"] = original_state_dict.pop( + "encoder.mid.attn_1.proj_out.weight" + ).squeeze() + converted_state_dict["encoder.mid_block.attentions.0.to_out.0.bias"] = original_state_dict.pop( + "encoder.mid.attn_1.proj_out.bias" + ) + + # encoder.mid_block + for i in range(2): + for j in range(2): + # conv + converted_state_dict[f"encoder.mid_block.resnets.{i}.conv{j+1}.weight"] = original_state_dict.pop( + f"encoder.mid.block_{i+1}.conv{j+1}.weight" + ) + converted_state_dict[f"encoder.mid_block.resnets.{i}.conv{j+1}.bias"] = original_state_dict.pop( + f"encoder.mid.block_{i+1}.conv{j+1}.bias" + ) + + # norm + converted_state_dict[f"encoder.mid_block.resnets.{i}.norm{j+1}.weight"] = original_state_dict.pop( + f"encoder.mid.block_{i+1}.norm{j+1}.weight" + ) + converted_state_dict[f"encoder.mid_block.resnets.{i}.norm{j+1}.bias"] = original_state_dict.pop( + f"encoder.mid.block_{i+1}.norm{j+1}.bias" + ) + + # decoder.conv_in + converted_state_dict["decoder.conv_in.weight"] = original_state_dict.pop( + "decoder.conv_in.weight" + ) + converted_state_dict["decoder.conv_in.bias"] = original_state_dict.pop( + "decoder.conv_in.bias" + ) + + # decoder.conv_out + converted_state_dict["decoder.conv_out.weight"] = original_state_dict.pop( + "decoder.conv_out.weight" + ) + converted_state_dict["decoder.conv_out.bias"] = original_state_dict.pop( + "decoder.conv_out.bias" + ) + + # decoder.norm_out + converted_state_dict["decoder.conv_norm_out.weight"] = original_state_dict.pop( + "decoder.norm_out.weight" + ) + converted_state_dict["decoder.conv_norm_out.bias"] = original_state_dict.pop( + "decoder.norm_out.bias" + ) + + # decoder.mid + converted_state_dict["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop( + "decoder.mid.attn_1.q.weight" + ).squeeze() + converted_state_dict["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop( + "decoder.mid.attn_1.q.bias" + ) + converted_state_dict["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop( + "decoder.mid.attn_1.k.weight" + ).squeeze() + converted_state_dict["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop( + "decoder.mid.attn_1.k.bias" + ) + converted_state_dict["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop( + "decoder.mid.attn_1.v.weight" + ).squeeze() + converted_state_dict["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop( + "decoder.mid.attn_1.v.bias" + ) + converted_state_dict["decoder.mid_block.attentions.0.group_norm.weight"] = original_state_dict.pop( + "decoder.mid.attn_1.norm.weight" + ) + converted_state_dict["decoder.mid_block.attentions.0.group_norm.bias"] = original_state_dict.pop( + "decoder.mid.attn_1.norm.bias" + ) + converted_state_dict["decoder.mid_block.attentions.0.to_out.0.weight"] = original_state_dict.pop( + "decoder.mid.attn_1.proj_out.weight" + ).squeeze() + converted_state_dict["decoder.mid_block.attentions.0.to_out.0.bias"] = original_state_dict.pop( + "decoder.mid.attn_1.proj_out.bias" + ) + + # decoder.mid_block + for i in range(2): + for j in range(2): + # conv + converted_state_dict[f"decoder.mid_block.resnets.{i}.conv{j+1}.weight"] = original_state_dict.pop( + f"decoder.mid.block_{i+1}.conv{j+1}.weight" + ) + converted_state_dict[f"decoder.mid_block.resnets.{i}.conv{j+1}.bias"] = original_state_dict.pop( + f"decoder.mid.block_{i+1}.conv{j+1}.bias" + ) + + # norm + converted_state_dict[f"decoder.mid_block.resnets.{i}.norm{j+1}.weight"] = original_state_dict.pop( + f"decoder.mid.block_{i+1}.norm{j+1}.weight" + ) + converted_state_dict[f"decoder.mid_block.resnets.{i}.norm{j+1}.bias"] = original_state_dict.pop( + f"decoder.mid.block_{i+1}.norm{j+1}.bias" + ) + + # decoder.up + for i in range(4): + # conv & norm + for j in range(3): + for k in range(1, 3): + converted_state_dict[f"decoder.up_blocks.{3-i}.resnets.{j}.conv{k}.weight"] = original_state_dict.pop( + f"decoder.up.{i}.block.{j}.conv{k}.weight" + ) + converted_state_dict[f"decoder.up_blocks.{3-i}.resnets.{j}.conv{k}.bias"] = original_state_dict.pop( + f"decoder.up.{i}.block.{j}.conv{k}.bias" + ) + converted_state_dict[f"decoder.up_blocks.{3-i}.resnets.{j}.norm{k}.weight"] = original_state_dict.pop( + f"decoder.up.{i}.block.{j}.norm{k}.weight" + ) + converted_state_dict[f"decoder.up_blocks.{3-i}.resnets.{j}.norm{k}.bias"] = original_state_dict.pop( + f"decoder.up.{i}.block.{j}.norm{k}.bias" + ) + + # downsample + if i != 0 : + converted_state_dict[f"decoder.up_blocks.{3-i}.upsamplers.0.conv.weight"] = original_state_dict.pop( + f"decoder.up.{i}.upsample.conv.weight" + ) + converted_state_dict[f"decoder.up_blocks.{3-i}.upsamplers.0.conv.bias"] = original_state_dict.pop( + f"decoder.up.{i}.upsample.conv.bias" + ) + + # shortcut + if i == 0 or i == 1: + converted_state_dict[f"decoder.up_blocks.{3-i}.resnets.0.conv_shortcut.weight"] = original_state_dict.pop( + f"decoder.up.{i}.block.0.nin_shortcut.weight" + ) + converted_state_dict[f"decoder.up_blocks.{3-i}.resnets.0.conv_shortcut.bias"] = original_state_dict.pop( + f"decoder.up.{i}.block.0.nin_shortcut.bias" + ) + + return converted_state_dict + + +def main(args): + + if args.transformer: + original_ckpt = load_original_checkpoint(args) + + num_layers = 19 + num_single_layers = 38 + inner_dim = 3072 + mlp_ratio = 4.0 + + converted_transformer_state_dict = convert_step1x_edit_transformer_checkpoint_to_diffusers( + original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio + ) + transformer = Step1XEditTransformer2DModel( + in_channels=args.in_channels, out_channels=args.out_channels + ) + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") + + if args.vae: + original_ckpt = load_original_checkpoint(args) + converted_vae_state_dict = convert_step1x_edit_vae_checkpoint_to_diffusers(original_ckpt) + vae = AutoencoderKL( + in_channels = 3, + out_channels = 3, + down_block_types = [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + up_block_types = [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + block_out_channels = [ + 128, + 256, + 512, + 512 + ], + layers_per_block = 2, + act_fn = "silu", + latent_channels = 16, + norm_num_groups = 32, + sample_size = 1024, + scaling_factor = 0.3611, + shift_factor = 0.1159, + latents_mean = None, + latents_std = None, + force_upcast = True, + use_quant_conv = False, + use_post_quant_conv = False, + mid_block_add_attention = True, + ) + vae.load_state_dict(converted_vae_state_dict, strict=True) + vae.to(dtype).save_pretrained(f"{args.output_path}/vae") + + if args.text_encoder: + text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + args.original_state_dict_repo_id, + torch_dtype=dtype, + attn_implementation="sdpa", + ) + image_encoder = AutoProcessor.from_pretrained(args.original_state_dict_repo_id, min_pixels = 256 * 28 * 28, max_pixels = 324 * 28 * 28) + text_encoder.save_pretrained(f"{args.output_path}/text_encoder") + image_encoder.save_pretrained(f"{args.output_path}/processor") + + if args.scheduler: + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.original_state_dict_repo_id, + subfolder="scheduler" + ) + scheduler.save_pretrained(f"{args.output_path}/scheduler") + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 762ae3846a7d..08fe4db9fcba 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -230,6 +230,7 @@ "SparseControlNetModel", "StableAudioDiTModel", "StableCascadeUNet", + "Step1XEditTransformer2DModel", "T2IAdapter", "T5FilmDecoder", "Transformer2DModel", @@ -576,6 +577,7 @@ "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", + "Step1XEditPipeline", "TextToVideoSDPipeline", "TextToVideoZeroPipeline", "TextToVideoZeroSDXLPipeline", @@ -900,6 +902,7 @@ SkyReelsV2Transformer3DModel, SparseControlNetModel, StableAudioDiTModel, + Step1XEditTransformer2DModel, T2IAdapter, T5FilmDecoder, Transformer2DModel, @@ -1216,6 +1219,7 @@ StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, StableVideoDiffusionPipeline, + Step1XEditPipeline, TextToVideoSDPipeline, TextToVideoZeroPipeline, TextToVideoZeroSDXLPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 49ac2a1c56fd..b09e66650466 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -97,6 +97,7 @@ _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] + _import_structure["transformers.transformer_step1x_edit"] = ["Step1XEditTransformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] @@ -193,6 +194,7 @@ SD3Transformer2DModel, SkyReelsV2Transformer3DModel, StableAudioDiTModel, + Step1XEditTransformer2DModel, T5FilmDecoder, Transformer2DModel, TransformerTemporalModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..7511aee2e50a 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -34,6 +34,7 @@ from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel + from .transformer_step1x_edit import Step1XEditTransformer2DModel from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_step1x_edit.py b/src/diffusers/models/transformers/transformer_step1x_edit.py new file mode 100644 index 000000000000..26fbf503bdbd --- /dev/null +++ b/src/diffusers/models/transformers/transformer_step1x_edit.py @@ -0,0 +1,1101 @@ +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import math +from functools import partial +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.import_utils import is_torch_npu_available +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import ( + Timesteps, + apply_rotary_emb, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "Step1XEditAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "Step1XEditAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "Step1XEditAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +def apply_gate(x, gate=None, tanh=False): + """Applies a gating mechanism to the input tensor + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +class Step1XEditAttnProcessor: + _attention_backend = None + + def __init__(self): + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__( + self, + attn: "Step1XEditAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, backend=self._attention_backend + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class Step1XEditAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = Step1XEditAttnProcessor + _available_processors = [ + Step1XEditAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-6, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +@maybe_allow_in_graph +class Step1XEditSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + processor = Step1XEditAttnProcessor() + + self.attn = Step1XEditAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + + return encoder_hidden_states, hidden_states + + +@maybe_allow_in_graph +class Step1XEditTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = Step1XEditAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=Step1XEditAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = hidden_states + ff_output + + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class Step1XEditPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Step1XEditMLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = (bias, bias) + drop_probs = (drop, drop) + linear_layer = nn.Linear + + self.fc1 = linear_layer( + in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype + ) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_channels, device=device, dtype=dtype) + if norm_layer is not None + else nn.Identity() + ) + self.fc2 = linear_layer( + hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Step1XEditMLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.gradient_checkpointing = False + + def set_gradient_checkpointing(self, enable: bool): + self.gradient_checkpointing = enable + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class Step1XEditCrossAttnBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + qk_norm: bool = False, + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6 + ) + self.norm1_2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6 + ) + self.self_attn_q = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias + ) + self.self_attn_kv = nn.Linear( + hidden_size, hidden_size*2, bias=qkv_bias + ) + qk_norm_layer = nn.LayerNorm + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6 + ) + act_layer = nn.SiLU + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor=None, + + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + norm_y = self.norm1_2(y) + q = self.self_attn_q(norm_x) + q = q.view(q.size(0), q.size(1), self.heads_num, -1).permute(0, 2, 1, 3).contiguous() + kv = self.self_attn_kv(norm_y) + k, v = kv.view(kv.size(0), kv.size(1), 2, self.heads_num, -1).permute(2, 0, 3, 1, 4).contiguous().unbind(0) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).transpose(1, 2) + attn = attn.reshape(attn.size(0), attn.size(1), -1) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + return x + + +class Step1XEditIndividualTokenRefinerBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + qk_norm: bool = False, + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6 + ) + self.self_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=qkv_bias + ) + qk_norm_layer = nn.LayerNorm + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias + ) + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6 + ) + act_layer = nn.SiLU + self.mlp = Step1XEditMLP( + in_channels=hidden_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop_rate, + ) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + y: torch.Tensor = None, + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = qkv.view(qkv.size(0), qkv.size(1), 3, self.heads_num, -1).permute(2, 0, 3, 1, 4).contiguous().unbind(0) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).transpose(1, 2) + attn = attn.reshape(attn.size(0), attn.size(1), -1) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + + return x + + +class Step1XEditIndividualTokenRefiner(torch.nn.Module): + def __init__( + self, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + qk_norm: bool = False, + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.blocks = nn.ModuleList( + [ + Step1XEditIndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + qk_norm=qk_norm, + qkv_bias=qkv_bias, + ) + for _ in range(depth) + ] + ) + + def forward( + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + y:torch.Tensor=None, + ): + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] + seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( + 1, 1, seq_len, 1 + ) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + self_attn_mask[:, :, :, 0] = True + + for block in self.blocks: + x = block(x, c, self_attn_mask,y) + + return x + + +class Step1XEditTimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None, + ): + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, hidden_size, bias=True + ), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore + nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding( + t, self.frequency_embedding_size, self.max_period + ).type(self.mlp[0].weight.dtype) # type: ignore + t_emb = self.mlp(t_freq) + return t_emb + + +class Step1XEditTextProjection(nn.Module): + """ + Projects text embeddings. Also handles dropout for classifier-free guidance. + + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + super().__init__() + self.linear_1 = nn.Linear( + in_features=in_channels, + out_features=hidden_size, + bias=True, + ) + self.act_1 = act_layer() + self.linear_2 = nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias=True, + ) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class Step1XEditSingleTokenRefiner(torch.nn.Module): + """ + A single token refiner block for llm text embedding refine. + """ + def __init__( + self, + in_channels, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + qk_norm: bool = False, + qkv_bias: bool = True, + attn_mode: str = "torch", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + super().__init__() + self.attn_mode = attn_mode + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." + + self.input_embedder = nn.Linear( + in_channels, hidden_size, bias=True + ) + + act_layer = nn.SiLU + # Build timestep embedding layer + self.t_embedder = Step1XEditTimestepEmbedder(hidden_size, act_layer) + # Build context embedding layer + self.c_embedder = Step1XEditTextProjection( + in_channels, hidden_size, act_layer + ) + + self.individual_token_refiner = Step1XEditIndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + qk_norm=qk_norm, + qkv_bias=qkv_bias, + ) + + def forward( + self, + x: torch.Tensor, + t: torch.LongTensor, + mask: Optional[torch.LongTensor] = None, + ): + timestep_aware_representations = self.t_embedder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.unsqueeze(-1) # [b, s1, 1] + context_aware_representations = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + x = self.individual_token_refiner(x, c, mask) + + return x + + +class Step1XEditConnector(torch.nn.Module): + def __init__( + self, + in_channels=3584, + hidden_size=4096, + heads_num=32, + depth=2, + ): + super().__init__() + + self.S = Step1XEditSingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth) + self.global_proj_out=nn.Linear(in_channels, 768) + + def forward(self, x, t, mask): + t = t * 1000 + mask_float = mask.unsqueeze(-1) # [b, s1, 1] + + x_mean = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) + + global_out = self.global_proj_out(x_mean) + encoder_hidden_states = self.S(x,t,mask) + return encoder_hidden_states, global_out + + +class Step1XEditTransformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + CacheMixin, + AttentionMixin, +): + """ + The Transformer model introduced in Step1X-Edit. + + Reference: https://arxiv.org/abs/2504.17761 + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `19`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `38`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `4096`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + pooled_projection_dim (`int`, defaults to `768`): + The number of dimensions to use for the pooled projection. + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Step1XEditTransformerBlock", "Step1XEditSingleTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["Step1XEditTransformerBlock", "Step1XEditSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + timestep_in_dim: int = 256, + vector_in_dim: int = 768, + connector_in_channels=3584, + connector_hidden_size=4096, + connector_heads_num=32, + connector_depth=2, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.pos_embed = Step1XEditPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + self.time_embed = Step1XEditMLPEmbedder(timestep_in_dim, self.inner_dim) + self.vec_embed = Step1XEditMLPEmbedder(vector_in_dim, self.inner_dim) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + + self.connector = Step1XEditConnector( + connector_in_channels, + connector_hidden_size, + connector_heads_num, + connector_depth, + ) + + self.transformer_blocks = nn.ModuleList( + [ + Step1XEditTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + Step1XEditSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + prompt_embeds_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`Step1XEditTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + encoder_hidden_states, y = self.connector( + encoder_hidden_states, timestep, prompt_embeds_mask + ) + hidden_states = self.x_embedder(hidden_states) + + temb = self.time_embed(self.time_proj(timestep * 1000).to(timestep)) + temb = temb + self.vec_embed(y) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 25d5d213cf33..afb8378f0818 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -367,6 +367,7 @@ "TextToVideoZeroSDXLPipeline", "VideoToVideoSDPipeline", ] + _import_structure["step1x_edit"] = ["Step1XEditPipeline"] _import_structure["i2vgen_xl"] = ["I2VGenXLPipeline"] _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] _import_structure["unidiffuser"] = [ @@ -762,6 +763,7 @@ StableDiffusionXLPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline + from .step1x_edit import Step1XEditPipeline from .t2i_adapter import ( StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline, diff --git a/src/diffusers/pipelines/step1x_edit/__init__.py b/src/diffusers/pipelines/step1x_edit/__init__.py new file mode 100644 index 000000000000..62d2563d33b2 --- /dev/null +++ b/src/diffusers/pipelines/step1x_edit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["Step1XEditPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_step1x_edit"] = ["Step1XEditPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_step1x_edit import Step1XEditPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/step1x_edit/pipeline_output.py b/src/diffusers/pipelines/step1x_edit/pipeline_output.py new file mode 100644 index 000000000000..afe0df503010 --- /dev/null +++ b/src/diffusers/pipelines/step1x_edit/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class Step1XEditPipelineOutput(BaseOutput): + """ + Output class for Step1X-Edit pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/pipelines/step1x_edit/pipeline_step1x_edit.py b/src/diffusers/pipelines/step1x_edit/pipeline_step1x_edit.py new file mode 100644 index 000000000000..badddaadf228 --- /dev/null +++ b/src/diffusers/pipelines/step1x_edit/pipeline_step1x_edit.py @@ -0,0 +1,1020 @@ +# Copyright 2025 Step1X-Edit Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import math +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL, Step1XEditTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler + +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import Step1XEditPipelineOutput + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Step1XEditPipeline + >>> from diffusers.utils import load_image + + >>> pipe = Step1XEditPipeline.from_pretrained("stepfun-ai/Step1X-Edit-v1p1-diffusers", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = "Make Pikachu hold a sign that says 'Step1X-Edit is awesome', yarn art style, detailed, vibrant colors" + + >>> image = pipe( + image=image, + prompt=prompt, + num_inference_steps=28, + true_cfg_scale=6.0, + generator=torch.Generator().manual_seed(42), + ).images[0] + >>> image.save("output.png") + ``` +""" + + + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Step1XEditPipeline(DiffusionPipeline): + r""" + The Step1X-Edit pipeline for image-to-image and text-to-image generation. + + Reference: https://arxiv.org/abs/2504.17761 + + Args: + transformer ([`Step1XEditTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) + processor (`Qwen2_5_VLProcessor`): + [Qwen2_5_VLProcessor](https://huggingface.co/docs/transformers/v4.53.3/en/model_doc/qwen2_5_vl#transformers.Qwen2_5_VLProcessor). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + processor: Qwen2_5_VLProcessor, + transformer: Step1XEditTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.image_encoder=None + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Step1X-Edit latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.max_token_length = 640 + self.default_sample_size = 128 + self.QWEN25VL_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: +- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. +- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n +Here are examples of how to transform or refine prompts: +- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. +- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n +Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: +User Prompt:''' + + def _split_string(self, s): + s = s.replace("'", '"').replace("“", '"').replace("”", '"') # use english quotes + result = [] + in_quotes = False + temp = "" + + for idx,char in enumerate(s): + if char == '"' and idx>155: # system token + temp += char + if not in_quotes: + result.append(temp) + temp = "" + + in_quotes = not in_quotes + continue + if in_quotes: + if char.isspace(): + pass # have space token + + result.append("“" + char + "”") + else: + temp += char + + if temp: + result.append(temp) + + return result + + def _get_qwenvl_embeds( + self, + prompt: Union[str, List[str]], + ref_image: Optional[torch.Tensor], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = torch.bfloat16, + ): + text_list = prompt + embs = torch.zeros( + len(text_list), + self.max_token_length, + self.text_encoder.config.hidden_size, + dtype=dtype, + device=device, + ) + hidden_states = torch.zeros( + len(text_list), + self.max_token_length, + self.text_encoder.config.hidden_size, + dtype=dtype, + device=device, + ) + masks = torch.zeros( + len(text_list), + self.max_token_length, + dtype=torch.long, + device=device, + ) + + for idx, (txt, imgs) in enumerate(zip(text_list, ref_image)): + + messages = [ + { + "role": "user", + "content": [] + } + ] + + messages[0]["content"].append({"type": "text", "text": f"{self.QWEN25VL_PREFIX}"}) + messages[0]['content'].append({"type": "image", "image": imgs}) + messages[0]["content"].append({"type": "text", "text": f"{txt}"}) + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, add_vision_id=True + ) + imgs = imgs.convert("RGB") + min_pixels = 4 * 28 * 28 + max_pixels = 16384 * 28 * 28 + width, height = imgs.size + h_bar = max(28, round(height / 28) * 28) + w_bar = max(28, round(width / 28) * 28) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / 28) * 28 + w_bar = math.floor(width / beta / 28) * 28 + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / 28) * 28 + w_bar = math.ceil(width * beta / 28) * 28 + image_inputs = [imgs.resize((w_bar, h_bar))] + + inputs = self.processor( + text=[text], + images=image_inputs, + padding=True, + return_tensors="pt", + ) + + old_inputs_ids = inputs.input_ids + text_split_list = self._split_string(text) + + token_list = [] + for text_each in text_split_list: + txt_inputs = self.processor( + text=text_each, + images=None, + videos=None, + padding=True, + return_tensors="pt", + ) + token_each=txt_inputs.input_ids + if token_each[0][0] == 2073 and token_each[0][-1] == 854: + token_each = token_each[:,1:-1] + token_list.append(token_each) + else: + token_list.append(token_each) + + new_txt_ids=torch.cat(token_list,dim=1).to("cuda") + + new_txt_ids = new_txt_ids.to(old_inputs_ids.device) + idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0] + idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0] + inputs.input_ids = torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]],dim=0).unsqueeze(0).to("cuda") + inputs.attention_mask= (inputs.input_ids>0).long().to("cuda") + outputs = self.text_encoder(input_ids = inputs.input_ids, attention_mask = inputs.attention_mask, pixel_values = inputs.pixel_values.to("cuda"), image_grid_thw = inputs.image_grid_thw.to("cuda"), output_hidden_states=True) + + emb = outputs['hidden_states'][-1] + embs[idx,:min(self.max_token_length,emb.shape[1]-217)] = emb[0,217:][:self.max_token_length] + masks[idx,:min(self.max_token_length,emb.shape[1]-217)]=torch.ones((min(self.max_token_length,emb.shape[1]-217)), dtype=torch.long, device=torch.cuda.current_device()) + + return embs, masks + + def encode_prompt( + self, + ref_image: Optional[torch.Tensor], + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + ref_image = [ref_image] if isinstance(prompt, str) else ref_image # change + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwenvl_embeds(prompt, ref_image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device) + + return prompt_embeds, prompt_embeds_mask, text_ids + + def encode_image( + self, + image: Optional[torch.Tensor], + width: Optional[int] = None, + height: Optional[int] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + ): + + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + img_info = image.size + width, height = img_info + aspect_ratio = width / height + + if width > height: + width_new = math.ceil(math.sqrt(1024 * 1024 * aspect_ratio)) + height_new = math.ceil(width_new / aspect_ratio) + else: + height_new = math.ceil(math.sqrt(1024 * 1024 / aspect_ratio)) + width_new = math.ceil(height_new * aspect_ratio) + + multiple_of = self.vae_scale_factor * 2 + height_new = height_new // multiple_of * multiple_of + width_new = width_new // multiple_of * multiple_of + + if height != height_new or width != width_new: + logger.warning( + f"Generation `height` and `width` have been adjusted to {height_new} and {width_new} to fit the model requirements." + ) + height, width = height_new, width_new + ref_image = self.image_processor.resize(image, height, width) + image = self.image_processor.preprocess(ref_image, height, width).contiguous() + else: + width = width if width is not None else 1024 + height = height if height is not None else 1024 + img_info = (width, height) + ref_image = torch.zeros(3, 1024, 1024).unsqueeze(0).to(device) + ref_image = self.image_processor.pt_to_numpy(ref_image) + ref_image = self.image_processor.numpy_to_pil(ref_image)[0] + image = None + + return image, ref_image, img_info, width, height + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + @staticmethod + def _output_process_image(image, image_size): + res_image = [img.resize(image_size) for img in image] + return res_image + + @staticmethod + def process_diff_norm(diff_norm, k): + pow_result = torch.pow(diff_norm, k) + + result = torch.where( + diff_norm > 1.0, + pow_result, + torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm), + ) + return result + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + image_latents = None + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample") + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + image: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + + image_latents = image_ids = None + if image is not None: + image = image.to(device=device, dtype=dtype) + + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[2:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + image_ids = self._prepare_latent_image_ids( + batch_size, image_latent_height // 2, image_latent_width // 2, device, torch.float32 # change + # batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + image_ids[..., 1] = image_ids[..., 1] + 1 + image_ids[..., 2] = image_ids[..., 2] + 1 + latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # change + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latent_ids, image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 6.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 6.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + timesteps_truncate: float = 0.93, + process_norm_power: float = 0.4 + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 6.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.step1x_edit.Step1XEditPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.step1x_edit.Step1XEditPipelineOutput`] or `tuple`: + [`~pipelines.step1x_edit.Step1XEditPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + device = self._execution_device + + # 1. Preprocess image + image, ref_image, img_info, width, height = self.encode_image( + image, + width, + height, + device, + num_images_per_prompt + ) + + # 2. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + if not has_neg_prompt: + negative_prompt = "" if image is not None else "worst quality, wrong limbs, unreasonable limbs, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting" + do_true_cfg = true_cfg_scale > 1 + ( + prompt_embeds, + prompt_embeds_mask, + text_ids + ) = self.encode_prompt( + ref_image=ref_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_text_ids, + ) = self.encode_prompt( + ref_image=ref_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents, latent_ids, image_ids = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if image_ids is not None: + latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + if t.item() > timesteps_truncate: + diff = noise_pred - neg_noise_pred + diff_norm = torch.norm(diff, dim=(2), keepdim=True) + noise_pred = neg_noise_pred + true_cfg_scale * ( + noise_pred - neg_noise_pred + ) / self.process_diff_norm(diff_norm, k=process_norm_power) + else: + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + image = self._output_process_image(image, img_info) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Step1XEditPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 91eefc5c10e0..0d5c6db62c03 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2987,6 +2987,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Step1XEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class TextToVideoSDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]