Commit 416ff5d

flux
1 parent 4724606 commit 416ff5d

24 files changed: +2387 -18 lines

diffsynth/configs/model_configs.py

Lines changed: 28 additions & 0 deletions
@@ -285,6 +285,34 @@
         "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
         "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
     },
+    {
+        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
+        "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
+        "model_name": "flux_vae_encoder",
+        "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
+    },
+    {
+        # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
+        "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
+        "model_name": "flux_vae_decoder",
+        "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
+    },
+    {
+        # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
+        "model_hash": "d02f41c13549fa5093d3521f62a5570a",
+        "model_name": "flux_dit",
+        "model_class": "diffsynth.models.flux_dit.FluxDiT",
+        "extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
+    },
+    {
+        # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
+        "model_hash": "0629116fce1472503a66992f96f3eb1a",
+        "model_name": "flux_value_controller",
+        "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
+    }
 ]
 
 MODEL_CONFIGS = qwen_image_series + wan_series + flux_series
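
Note: the "# Example:" comments in the new entries show how each checkpoint is expected to be requested. A minimal usage sketch built from those comments (the import path for ModelConfig and the hash-based dispatch described in the final comment are assumptions, not shown in this hunk):

# Sketch only: ModelConfig arguments copied from the "# Example:" comments above;
# the import path is an assumption about the package layout.
from diffsynth import ModelConfig

flux_vae = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
flux_dit = ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
# Presumably the loader matches downloaded weights against MODEL_CONFIGS via
# "model_hash" to select the registered "model_class" and "state_dict_converter".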

diffsynth/models/flux_ipadapter.py

Lines changed: 32 additions & 3 deletions
@@ -1,9 +1,38 @@
-from .svd_image_encoder import SVDImageEncoder
-from .sd3_dit import RMSNorm
-from transformers import CLIPImageProcessor
+from .general_modules import RMSNorm
+from transformers import SiglipVisionModel, SiglipVisionConfig
 import torch
 
 
+class SiglipVisionModelSO400M(SiglipVisionModel):
+    def __init__(self):
+        config = SiglipVisionConfig(**{
+            "architectures": [
+                "SiglipModel"
+            ],
+            "initializer_factor": 1.0,
+            "model_type": "siglip",
+            "text_config": {
+                "hidden_size": 1152,
+                "intermediate_size": 4304,
+                "model_type": "siglip_text_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27
+            },
+            "torch_dtype": "float32",
+            "transformers_version": "4.37.0.dev0",
+            "vision_config": {
+                "hidden_size": 1152,
+                "image_size": 384,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "patch_size": 14
+            }
+        })
+        super().__init__(config)
+
+
 class MLPProjModel(torch.nn.Module):
     def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
         super().__init__()
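
Side note: the nested "vision_config" values above match the published SigLIP-SO400M-patch14-384 vision settings. For reference only, a sketch that spells those numbers out as direct SiglipVisionConfig arguments (this is not how the commit constructs the model):

# Reference sketch, not from the commit: the SO400M vision hyperparameters
# passed as explicit SiglipVisionConfig arguments.
from transformers import SiglipVisionConfig

so400m_vision = SiglipVisionConfig(
    hidden_size=1152,
    intermediate_size=4304,
    num_hidden_layers=27,
    num_attention_heads=16,
    image_size=384,
    patch_size=14,
)
print(so400m_vision.hidden_size, so400m_vision.image_size)  # 1152 384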

diffsynth/models/flux_vae.py

Lines changed: 19 additions & 12 deletions
@@ -106,7 +106,7 @@ def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_ba
         return model_output
 
 
-class Attention(torch.nn.Module):
+class ConvAttention(torch.nn.Module):
 
     def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
         super().__init__()
@@ -115,20 +115,25 @@ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_k
         self.num_heads = num_heads
         self.head_dim = head_dim
 
-        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
-        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
-        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
-        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
+        self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
+        self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
+        self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
+        self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
 
     def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
         batch_size = encoder_hidden_states.shape[0]
 
-        q = self.to_q(hidden_states)
-        k = self.to_k(encoder_hidden_states)
-        v = self.to_v(encoder_hidden_states)
+        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
+        q = self.to_q(conv_input)
+        q = rearrange(q[:, :, :, 0], "B C L -> B L C")
+        conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
+        k = self.to_k(conv_input)
+        v = self.to_v(conv_input)
+        k = rearrange(k[:, :, :, 0], "B C L -> B L C")
+        v = rearrange(v[:, :, :, 0], "B C L -> B L C")
 
         q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -138,7 +143,9 @@ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
 
-        hidden_states = self.to_out(hidden_states)
+        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
+        hidden_states = self.to_out(conv_input)
+        hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
 
         return hidden_states
 
@@ -152,7 +159,7 @@ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_lay
         self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
 
         self.transformer_blocks = torch.nn.ModuleList([
-            Attention(
+            ConvAttention(
                 inner_dim,
                 num_attention_heads,
                 attention_head_dim,
@@ -236,7 +243,7 @@ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
         return hidden_states, time_emb, text_emb, res_stack
 
 
-class SD3VAEDecoder(torch.nn.Module):
+class FluxVAEDecoder(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.scaling_factor = 0.3611
@@ -308,7 +315,7 @@ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
         return hidden_states
 
 
-class SD3VAEEncoder(torch.nn.Module):
+class FluxVAEEncoder(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.scaling_factor = 0.3611
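
The Attention-to-ConvAttention change replaces the nn.Linear projections with 1x1 nn.Conv2d layers, presumably so the module lines up with checkpoints that store these projections as conv weights. A self-contained check of the underlying equivalence (illustrative only, not code from this commit):

# A 1x1 Conv2d over tokens laid out as (B, C, L, 1) computes the same projection
# as a Linear layer over (B, L, C); copying the weights shows the outputs agree.
import torch
from einops import rearrange

B, L, C, D = 2, 16, 512, 512
linear = torch.nn.Linear(C, D, bias=False)
conv = torch.nn.Conv2d(C, D, kernel_size=(1, 1), bias=False)
with torch.no_grad():
    conv.weight.copy_(linear.weight[:, :, None, None])  # same parameters, conv layout

x = torch.randn(B, L, C)
y_linear = linear(x)
y_conv = rearrange(conv(rearrange(x, "B L C -> B C L 1"))[:, :, :, 0], "B C L -> B L C")
assert torch.allclose(y_linear, y_conv, atol=1e-5)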

diffsynth/models/flux_value_control.py

Lines changed: 3 additions & 1 deletion
@@ -1,10 +1,12 @@
 import torch
-from diffsynth.models.svd_unet import TemporalTimesteps
+from .general_modules import TemporalTimesteps
 
 
 class MultiValueEncoder(torch.nn.Module):
     def __init__(self, encoders=()):
         super().__init__()
+        if not isinstance(encoders, list):
+            encoders = [encoders]
         self.encoders = torch.nn.ModuleList(encoders)
 
     def __call__(self, values, dtype):
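
The new isinstance guard lets MultiValueEncoder accept a single encoder module instead of requiring a list. A small illustrative sketch (the Linear stand-in is a placeholder; SingleValueEncoder's constructor is not shown in this hunk):

# Sketch only: a plain Linear stands in for one value encoder.
import torch
from diffsynth.models.flux_value_control import MultiValueEncoder

single_encoder = torch.nn.Linear(1, 4)
wrapped = MultiValueEncoder(encoders=single_encoder)  # previously this required [single_encoder]
assert len(wrapped.encoders) == 1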
