Skip to content

Commit cceffc4

Browse files
committed
make style and make quality
1 parent dba4c1f commit cceffc4

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

scripts/convert_flux2_to_diffusers.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import os
3-
import pathlib
43
from contextlib import nullcontext
54
from typing import Any, Dict, Optional, Tuple
65

@@ -11,7 +10,6 @@
1110

1211
from diffusers import Flux2Transformer2DModel
1312
from diffusers.utils.import_utils import is_accelerate_available
14-
from transformers import Mistral3ForConditionalGeneration, AutoProcessor
1513

1614

1715
"""
@@ -22,7 +20,7 @@
2220
CTX = init_empty_weights if is_accelerate_available() else nullcontext
2321

2422

25-
FLUX2_TRANSFORMER_KEYS_RENAME_DICT ={
23+
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
2624
# Image and text input projections
2725
"img_in": "x_embedder",
2826
"txt_in": "context_embedder",
@@ -82,7 +80,7 @@ def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None
8280
# Skip if not a weight
8381
if ".weight" not in key:
8482
return
85-
83+
8684
# If adaLN_modulation is in the key, swap scale and shift parameters
8785
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
8886
if "adaLN_modulation" in key:
@@ -100,7 +98,7 @@ def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) ->
10098
# Skip if not a weight, bias, or scale
10199
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
102100
return
103-
101+
104102
new_prefix = "transformer_blocks"
105103
if "double_blocks." in key:
106104
parts = key.split(".")
@@ -111,7 +109,7 @@ def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) ->
111109

112110
if param_type == "scale":
113111
param_type = "weight"
114-
112+
115113
if "qkv" in within_block_name:
116114
fused_qkv_weight = state_dict.pop(key)
117115
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
@@ -146,7 +144,7 @@ def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) ->
146144
# Skip if not a weight, bias, or scale
147145
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
148146
return
149-
147+
150148
# Mapping:
151149
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
152150
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
@@ -215,7 +213,7 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
215213
"axes_dims_rope": (32, 32, 32, 32),
216214
"rope_theta": 2000,
217215
"eps": 1e-6,
218-
}
216+
},
219217
}
220218
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
221219
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP

src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import inspect
1616
from typing import Any, Dict, List, Optional, Tuple, Union
1717

18-
import numpy as np
1918
import torch
2019
import torch.nn as nn
2120
import torch.nn.functional as F

0 commit comments

Comments (0)