@@ -94,6 +94,12 @@
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
     "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
     "animatediff_rgb": "controlnet_cond_embedding.weight",
+    "auraflow": [
+        "double_layers.0.attn.w2q.weight",
+        "double_layers.0.attn.w1q.weight",
+        "cond_seq_linear.weight",
+        "t_embedder.mlp.0.weight",
+    ],
     "flux": [
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -154,6 +160,7 @@
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
     "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+    "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
     "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -635,6 +642,9 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
         model_type = "hunyuan-video"
 
+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
+        model_type = "auraflow"
+
     elif (
         CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
         and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
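
A note on the detection branch above: unlike most entries in CHECKPOINT_KEY_NAMES, the "auraflow" entry is a list, and the new elif requires every fingerprint key to be present, presumably because no single one of these names is unique to AuraFlow checkpoints. A toy sketch of that membership test (the fake state dict below is illustrative, not taken from this PR):

auraflow_keys = [
    "double_layers.0.attn.w2q.weight",
    "double_layers.0.attn.w1q.weight",
    "cond_seq_linear.weight",
    "t_embedder.mlp.0.weight",
]

# Fake state dict: tensor values replaced with None for illustration.
checkpoint = {key: None for key in auraflow_keys}
checkpoint["double_layers.0.attn.w2k.weight"] = None  # extra keys are fine

# Mirrors the new branch: the type matches only if all fingerprint keys exist.
assert all(key in checkpoint for key in auraflow_keys)
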
@@ -2090,6 +2100,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2689,3 +2700,95 @@ def update_state_dict_(state_dict, old_key, new_key):
             handler_fn_inplace(key, checkpoint)
 
     return checkpoint
+
+
+def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    state_dict_keys = list(checkpoint.keys())
+
+    # Handle register tokens
+    converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
+
+    # Handle time step projection
+    converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
+    converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
+    converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
+    converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
+
+    # Handle context embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
+
+    # Calculate the number of layers
+    def calculate_layers(keys, key_prefix):
+        layers = set()
+        for k in keys:
+            if key_prefix in k:
+                layer_num = int(k.split(".")[1])  # get the layer number
+                layers.add(layer_num)
+        return len(layers)
+
+    mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+    single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+    # MMDiT blocks
+    for i in range(mmdit_layers):
+        # Feed-forward
+        path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+        weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for orig_k, diffuser_k in path_mapping.items():
+            for k, v in weight_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.{orig_k}.{k}.weight", None
+                )
+
+        # Norms
+        path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+        for orig_k, diffuser_k in path_mapping.items():
+            converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
+                f"double_layers.{i}.{orig_k}.1.weight", None
+            )
+
+        # Attentions
+        x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+        context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+        for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+            for k, v in attn_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.attn.{k}.weight", None
+                )
+
+    # Single-DiT blocks
+    for i in range(single_dit_layers):
+        # Feed-forward
+        mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for k, v in mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.mlp.{k}.weight", None
+            )
+
+        # Norms
+        converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+            f"single_layers.{i}.modCX.1.weight", None
+        )
+
+        # Attentions
+        x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+        for k, v in x_attn_mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.attn.{k}.weight", None
+            )
+    # Final blocks
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
+
+    # Handle the final norm layer
+    norm_weight = checkpoint.pop("modF.1.weight", None)
+    if norm_weight is not None:
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
+    else:
+        converted_state_dict["norm_out.linear.weight"] = None
+
+    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
+    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
+    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
+
+    return converted_state_dict
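
For reference, a minimal sketch of what this diff enables: loading an original-format AuraFlow checkpoint through the single-file path. It assumes AuraFlowTransformer2DModel exposes from_single_file via the single-file loader mixin once this change lands; the checkpoint file name in the URL is an assumption, not taken from this PR.

import torch
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel

# infer_diffusers_model_type matches the "auraflow" fingerprint keys, and the
# state dict is remapped by convert_auraflow_transformer_checkpoint_to_diffusers.
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/fal/AuraFlow-v0.3/blob/main/aura_flow_0.3.safetensors",  # assumed file name
    torch_dtype=torch.bfloat16,
)

# Drop the converted transformer into the standard pipeline.
pipe = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)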