|
140 | 140 | "net.blocks.0.self_attn.q_proj.weight", |
141 | 141 | "net.pos_embedder.dim_spatial_range", |
142 | 142 | ], |
| 143 | + "flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"], |
143 | 144 | } |
144 | 145 |
|
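For reference, `infer_diffusers_model_type` (further down in this diff) matches checkpoints against these marker keys in order, so a state dict containing either FLUX.2 modulation key should resolve to the new `flux-2-dev` type. A minimal smoke-test sketch of that expectation — the import path is the module this diff touches, but the dummy tensor and the test itself are assumptions, not part of the PR:

```python
# Hypothetical check: a lone FLUX.2 marker key should be enough to match,
# since the flux2 branch is checked before the generic flux branch.
import torch

from diffusers.loaders.single_file_utils import infer_diffusers_model_type

dummy_checkpoint = {"single_stream_modulation.lin.weight": torch.zeros(1)}
assert infer_diffusers_model_type(dummy_checkpoint) == "flux-2-dev"
```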
145 | 146 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = { |
|
189 | 190 | "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, |
190 | 191 | "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, |
191 | 192 | "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, |
| 193 | + "flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"}, |
192 | 194 | "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, |
193 | 195 | "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, |
194 | 196 | "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"}, |
@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint): |
649 | 651 | else: |
650 | 652 | model_type = "animatediff_v3" |
651 | 653 |
|
| 654 | + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]): |
| 655 | + model_type = "flux-2-dev" |
| 656 | + |
652 | 657 | elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): |
653 | 658 | if any( |
654 | 659 | g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] |
@@ -3647,3 +3652,168 @@ def rename_transformer_blocks_(key: str, state_dict): |
3647 | 3652 | handler_fn_inplace(key, converted_state_dict) |
3648 | 3653 |
|
3649 | 3654 | return converted_state_dict |
| 3655 | + |
| 3656 | + |
| 3657 | +def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): |
| 3658 | + FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { |
| 3659 | + # Image and text input projections |
| 3660 | + "img_in": "x_embedder", |
| 3661 | + "txt_in": "context_embedder", |
| 3662 | + # Timestep and guidance embeddings |
| 3663 | + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", |
| 3664 | + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", |
| 3665 | + "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1", |
| 3666 | + "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2", |
| 3667 | + # Modulation parameters |
| 3668 | + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", |
| 3669 | + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", |
| 3670 | + "single_stream_modulation.lin": "single_stream_modulation.linear", |
| 3671 | + # Final output layer |
| 3672 | + # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params |
| 3673 | + "final_layer.linear": "proj_out", |
| 3674 | + } |
| 3675 | + |
| 3676 | + FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = { |
| 3677 | + "final_layer.adaLN_modulation.1": "norm_out.linear", |
| 3678 | + } |
| 3679 | + |
| 3680 | + FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { |
| 3682 | + # Handle fused QKV projections separately since they must be split into separate Q, K, V projections |
| 3682 | + "img_attn.norm.query_norm": "attn.norm_q", |
| 3683 | + "img_attn.norm.key_norm": "attn.norm_k", |
| 3684 | + "img_attn.proj": "attn.to_out.0", |
| 3685 | + "img_mlp.0": "ff.linear_in", |
| 3686 | + "img_mlp.2": "ff.linear_out", |
| 3687 | + "txt_attn.norm.query_norm": "attn.norm_added_q", |
| 3688 | + "txt_attn.norm.key_norm": "attn.norm_added_k", |
| 3689 | + "txt_attn.proj": "attn.to_add_out", |
| 3690 | + "txt_mlp.0": "ff_context.linear_in", |
| 3691 | + "txt_mlp.2": "ff_context.linear_out", |
| 3692 | + } |
| 3693 | + |
| 3694 | + FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = { |
| 3695 | + "linear1": "attn.to_qkv_mlp_proj", |
| 3696 | + "linear2": "attn.to_out", |
| 3697 | + "norm.query_norm": "attn.norm_q", |
| 3698 | + "norm.key_norm": "attn.norm_k", |
| 3699 | + } |
| 3700 | + |
| 3701 | + def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None: |
| 3702 | + # Skip if not a weight, bias, or scale |
| 3703 | + if ".weight" not in key and ".bias" not in key and ".scale" not in key: |
| 3704 | + return |
| 3705 | + |
| 3706 | + # Mapping: |
| 3707 | + # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj |
| 3708 | + # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out |
| 3709 | + # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight |
| 3710 | + # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight |
| 3711 | + new_prefix = "single_transformer_blocks" |
| 3712 | + if "single_blocks." in key: |
| 3713 | + parts = key.split(".") |
| 3714 | + block_idx = parts[1] |
| 3715 | + within_block_name = ".".join(parts[2:-1]) |
| 3716 | + param_type = parts[-1] |
| 3717 | + |
| 3718 | + if param_type == "scale": |
| 3719 | + param_type = "weight" |
| 3720 | + |
| 3721 | + new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name] |
| 3722 | + new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type]) |
| 3723 | + |
| 3724 | + param = state_dict.pop(key) |
| 3725 | + state_dict[new_key] = param |
| 3726 | + |
| 3727 | + return |
| 3728 | + |
| 3729 | + def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None: |
| 3730 | + # Skip if not a weight or bias |
| 3731 | + if ".weight" not in key and ".bias" not in key: |
| 3732 | + return |
| 3733 | + |
| 3734 | + # If adaLN_modulation is in the key, swap scale and shift parameters |
| 3735 | + # Original implementation is (shift, scale); diffusers implementation is (scale, shift) |
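| | + # i.e. swap_scale_shift chunks the tensor in half along dim 0 and |
| | + # re-concatenates the halves in reverse order: [shift, scale] -> [scale, shift] |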
| 3736 | + if "adaLN_modulation" in key: |
| 3737 | + key_without_param_type, param_type = key.rsplit(".", maxsplit=1) |
| 3738 | + # Assume all such keys are in the AdaLayerNorm key map |
| 3739 | + new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type] |
| 3740 | + new_key = ".".join([new_key_without_param_type, param_type]) |
| 3741 | + |
| 3742 | + swapped_weight = swap_scale_shift(state_dict.pop(key), 0) |
| 3743 | + state_dict[new_key] = swapped_weight |
| 3744 | + |
| 3745 | + return |
| 3746 | + |
| 3747 | + def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None: |
| 3748 | + # Skip if not a weight, bias, or scale |
| 3749 | + if ".weight" not in key and ".bias" not in key and ".scale" not in key: |
| 3750 | + return |
| 3751 | + |
| 3752 | + new_prefix = "transformer_blocks" |
| 3753 | + if "double_blocks." in key: |
| 3754 | + parts = key.split(".") |
| 3755 | + block_idx = parts[1] |
| 3756 | + modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp |
| 3757 | + within_block_name = ".".join(parts[2:-1]) |
| 3758 | + param_type = parts[-1] |
| 3759 | + |
| 3760 | + if param_type == "scale": |
| 3761 | + param_type = "weight" |
| 3762 | + |
| 3763 | + if "qkv" in within_block_name: |
| 3764 | + fused_qkv_weight = state_dict.pop(key) |
| 3765 | + to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0) |
| 3766 | + if "img" in modality_block_name: |
| 3767 | + # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v} |
| 3769 | + new_q_name = "attn.to_q" |
| 3770 | + new_k_name = "attn.to_k" |
| 3771 | + new_v_name = "attn.to_v" |
| 3772 | + elif "txt" in modality_block_name: |
| 3773 | + # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj} |
| 3775 | + new_q_name = "attn.add_q_proj" |
| 3776 | + new_k_name = "attn.add_k_proj" |
| 3777 | + new_v_name = "attn.add_v_proj" |
| 3778 | + new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type]) |
| 3779 | + new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type]) |
| 3780 | + new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type]) |
| 3781 | + state_dict[new_q_key] = to_q_weight |
| 3782 | + state_dict[new_k_key] = to_k_weight |
| 3783 | + state_dict[new_v_key] = to_v_weight |
| 3784 | + else: |
| 3785 | + new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name] |
| 3786 | + new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type]) |
| 3787 | + |
| 3788 | + param = state_dict.pop(key) |
| 3789 | + state_dict[new_key] = param |
| 3790 | + return |
| 3791 | + |
| 3792 | + def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None: |
| 3793 | + state_dict[new_key] = state_dict.pop(old_key) |
| 3794 | + |
| 3795 | + TRANSFORMER_SPECIAL_KEYS_REMAP = { |
| 3796 | + "adaLN_modulation": convert_ada_layer_norm_weights, |
| 3797 | + "double_blocks": convert_flux2_double_stream_blocks, |
| 3798 | + "single_blocks": convert_flux2_single_stream_blocks, |
| 3799 | + } |
| 3800 | + |
| 3801 | + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} |
| 3802 | + |
| 3803 | + # Handle official code --> diffusers key remapping via the remap dict |
| 3804 | + for key in list(converted_state_dict.keys()): |
| 3805 | + new_key = key[:] |
| 3806 | + for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items(): |
| 3807 | + new_key = new_key.replace(replace_key, rename_key) |
| 3808 | + |
| 3809 | + update_state_dict(converted_state_dict, key, new_key) |
| 3810 | + |
| 3811 | + # Handle any special logic that can't be expressed as a simple 1:1 remapping using the handlers |
| 3812 | + # in TRANSFORMER_SPECIAL_KEYS_REMAP |
| 3813 | + for key in list(converted_state_dict.keys()): |
| 3814 | + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): |
| 3815 | + if special_key not in key: |
| 3816 | + continue |
| 3817 | + handler_fn_inplace(key, converted_state_dict) |
| 3818 | + |
| 3819 | + return converted_state_dict |
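
Taken together, the converter can be exercised end to end on a raw official-layout state dict. A hedged usage sketch — the checkpoint path is a placeholder, and `load_file` comes from `safetensors`, which single-file loading already depends on:

```python
# Sketch: convert an official-layout FLUX.2 transformer checkpoint.
# "flux2-dev.safetensors" is a placeholder path, not a file this PR ships.
from safetensors.torch import load_file

checkpoint = load_file("flux2-dev.safetensors")
converted = convert_flux2_transformer_checkpoint_to_diffusers(checkpoint)

# Fused attention projections are split, e.g.
#   double_blocks.0.img_attn.qkv.weight
#     -> transformer_blocks.0.attn.to_q.weight / .to_k.weight / .to_v.weight
# and the final AdaLN weight is reordered from (shift, scale) to (scale, shift):
#   final_layer.adaLN_modulation.1.weight -> norm_out.linear.weight
```

Note that the comprehension at the top of the function drains the input via `pop`, so the original `checkpoint` dict is empty after conversion — callers should hold on to the returned mapping rather than the source dict.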