
Commit 6b573ae

1 parent 015a059 commit 6b573ae

12 files changed: +506 additions, -68 deletions

comfy/latent_formats.py

Lines changed: 9 additions & 0 deletions
@@ -178,6 +178,15 @@ def process_in(self, latent):
     def process_out(self, latent):
         return (latent / self.scale_factor) + self.shift_factor
 
+class Flux2(LatentFormat):
+    latent_channels = 128
+
+    def process_in(self, latent):
+        return latent
+
+    def process_out(self, latent):
+        return latent
+
 class Mochi(LatentFormat):
     latent_channels = 12
     latent_dimensions = 3
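
Reader's note on comfy/latent_formats.py: a LatentFormat subclass declares how many channels a model's latents carry and how they are scaled on the way into and out of sampling. The new Flux2 format declares 128 latent channels and leaves latents untouched in both directions. A minimal standalone sketch of that pattern (illustrative class names only, not the actual comfy.latent_formats module):

import torch

class LatentFormatSketch:
    # Stand-in for the LatentFormat base class touched in the diff above.
    latent_channels = 4
    scale_factor = 1.0
    shift_factor = 0.0

    def process_in(self, latent):
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent):
        return (latent / self.scale_factor) + self.shift_factor

class Flux2Sketch(LatentFormatSketch):
    # Mirrors the new Flux2 class: 128 latent channels, identity in/out.
    latent_channels = 128

    def process_in(self, latent):
        return latent

    def process_out(self, latent):
        return latent

latent = torch.randn(1, Flux2Sketch.latent_channels, 32, 32)
assert torch.equal(Flux2Sketch().process_in(latent), latent)  # no scaling applied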

comfy/ldm/flux/layers.py

Lines changed: 63 additions & 27 deletions
@@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     return embedding
 
 class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
+    def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
         self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
 
     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
@@ -80,14 +80,14 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
 
 
 class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
 
         self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
 
 
 @dataclass
@@ -98,11 +98,11 @@ class ModulationOut:
 
 
 class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.is_double = double
         self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
+        self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
 
     def forward(self, vec: Tensor) -> tuple:
         if vec.ndim == 2:
@@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
     return tensor
 
 
+class SiLUActivation(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gate_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x1, x2 = x.chunk(2, dim=-1)
+        return self.gate_fn(x1) * x2
+
+
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -142,27 +152,44 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
             self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
 
         self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
 
         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+
+        if mlp_silu_act:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )
 
         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
 
         self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
 
         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+
+        if mlp_silu_act:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )
+
         self.flipped_img_txt = flipped_img_txt
 
     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
@@ -246,6 +273,8 @@ def __init__(
         mlp_ratio: float = 4.0,
         qk_scale: float = None,
         modulation=True,
+        mlp_silu_act=False,
+        bias=True,
         dtype=None,
         device=None,
         operations=None
@@ -257,17 +286,24 @@ def __init__(
         self.scale = qk_scale or head_dim**-0.5
 
         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+        self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        if mlp_silu_act:
+            self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
+            self.mlp_act = SiLUActivation()
+        else:
+            self.mlp_act = nn.GELU(approximate="tanh")
+
         # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
 
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
 
         self.hidden_size = hidden_size
         self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
 
-        self.mlp_act = nn.GELU(approximate="tanh")
         if modulation:
             self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
         else:
@@ -279,7 +315,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
         else:
             mod = vec
 
-        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
 
         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         del qkv
@@ -298,11 +334,11 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
 
 
 class LastLayer(nn.Module):
-    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
-        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
+        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
 
     def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
         if vec.ndim == 2:
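
Reader's note on comfy/ldm/flux/layers.py: the main addition is an optional gated (SwiGLU-style) MLP. When mlp_silu_act is set, the first Linear outputs 2 * mlp_hidden_dim features, SiLUActivation splits them in half and multiplies one half by SiLU of the other, and the second Linear maps mlp_hidden_dim back to hidden_size, with both Linears bias-free; SingleStreamBlock does the same bookkeeping via mlp_hidden_dim_first. A short self-contained sketch of that shape arithmetic, using plain torch.nn layers in place of the patched operations module:

import torch
from torch import nn

class SiLUActivation(nn.Module):
    # Same gating as the class added above: split the last dim in two, gate one half.
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return nn.functional.silu(x1) * x2

hidden_size, mlp_ratio = 64, 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)

gated_mlp = nn.Sequential(
    nn.Linear(hidden_size, mlp_hidden_dim * 2, bias=False),  # doubled width feeds the gate
    SiLUActivation(),                                        # width drops back to mlp_hidden_dim
    nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
)

x = torch.randn(1, 16, hidden_size)
assert gated_mlp(x).shape == x.shape  # (1, 16, 64) in and out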

comfy/ldm/flux/model.py

Lines changed: 61 additions & 19 deletions
@@ -15,6 +15,7 @@
     MLPEmbedder,
     SingleStreamBlock,
     timestep_embedding,
+    Modulation
 )
 
 @dataclass
@@ -33,6 +34,11 @@ class FluxParams:
     patch_size: int
     qkv_bias: bool
     guidance_embed: bool
+    global_modulation: bool = False
+    mlp_silu_act: bool = False
+    ops_bias: bool = True
+    default_ref_method: str = "offset"
+    ref_index_scale: float = 1.0
 
 
 class Flux(nn.Module):
@@ -58,13 +64,17 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
         self.hidden_size = params.hidden_size
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
-        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+        if params.vec_in_dim is not None:
+            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.vector_in = None
+
         self.guidance_in = (
-            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
         )
-        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
 
         self.double_blocks = nn.ModuleList(
             [
@@ -73,6 +83,9 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=params.global_modulation is False,
+                    mlp_silu_act=params.mlp_silu_act,
+                    proj_bias=params.ops_bias,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -81,13 +94,30 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
 
         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
        )
 
         if final_layer:
-            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+
+        if params.global_modulation:
+            self.double_stream_modulation_img = Modulation(
+                self.hidden_size,
+                double=True,
+                bias=False,
+                dtype=dtype, device=device, operations=operations
+            )
+            self.double_stream_modulation_txt = Modulation(
+                self.hidden_size,
+                double=True,
+                bias=False,
+                dtype=dtype, device=device, operations=operations
+            )
+            self.single_stream_modulation = Modulation(
+                self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
+            )
 
     def forward_orig(
         self,
@@ -103,9 +133,6 @@ def forward_orig(
         attn_mask: Tensor = None,
     ) -> Tensor:
 
-        if y is None:
-            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
-
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -118,9 +145,17 @@ def forward_orig(
         if guidance is not None:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
 
-        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+        if self.vector_in is not None:
+            if y is None:
+                y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+
         txt = self.txt_in(txt)
 
+        vec_orig = vec
+        if self.params.global_modulation:
+            vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
+
         if "post_input" in patches:
             for p in patches["post_input"]:
                 out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
@@ -177,6 +212,9 @@ def block_wrap(args):
 
         img = torch.cat((txt, img), 1)
 
+        if self.params.global_modulation:
+            vec, _ = self.single_stream_modulation(vec_orig)
+
         for i, block in enumerate(self.single_blocks):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
@@ -207,7 +245,7 @@ def block_wrap(args):
 
         img = img[:, txt.shape[1] :, ...]
 
-        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+        img = self.final_layer(img, vec_orig)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
     def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@@ -234,10 +272,10 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}
             h_offset += rope_options.get("shift_y", 0.0)
             w_offset += rope_options.get("shift_x", 0.0)
 
-        img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
+        img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
         img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
         return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
     def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -259,10 +297,10 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
            h = 0
            w = 0
            index = 0
-           ref_latents_method = kwargs.get("ref_latents_method", "offset")
+           ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
            for ref in ref_latents:
                if ref_latents_method == "index":
-                    index += 1
+                    index += self.params.ref_index_scale
                     h_offset = 0
                     w_offset = 0
                elif ref_latents_method == "uxo":
@@ -286,7 +324,11 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
                img = torch.cat([img, kontext], dim=1)
                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
 
-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
+
+        if len(self.params.axes_dim) == 4:  # Flux 2
+            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
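
Reader's note on the global_modulation path in comfy/ldm/flux/model.py: the blocks are constructed with per-block modulation disabled (modulation=params.global_modulation is False), the raw conditioning vector is kept as vec_orig, three shared bias-free Modulation modules are evaluated once per forward pass (image stream, text stream, single stream), and the blocks receive the precomputed modulation instead of deriving their own; the final layer still sees vec_orig. A rough sketch of that control flow with a deliberately simplified Modulation (hypothetical signature, not the comfy implementation):

import torch
from torch import nn

class Modulation(nn.Module):
    # Simplified non-double modulation: vec -> (shift, scale, gate), no bias.
    def __init__(self, dim, bias=False):
        super().__init__()
        self.lin = nn.Linear(dim, 3 * dim, bias=bias)

    def forward(self, vec):
        return self.lin(nn.functional.silu(vec)).chunk(3, dim=-1)

dim = 64
vec_orig = torch.randn(2, dim)                  # conditioning vector, computed once
single_stream_modulation = Modulation(dim)      # shared by every single-stream block

shift, scale, gate = single_stream_modulation(vec_orig)
x = torch.randn(2, 16, dim)                     # token sequence entering a block
x_mod = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)   # same idea as apply_mod
attn_out = torch.randn_like(x)                  # stand-in for the block's attention output
x = x + gate.unsqueeze(1) * attn_out            # gate scales the residual update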

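Reader's note on the position-id changes in the same file: both img_ids and txt_ids are now built in float32 with one column per entry of params.axes_dim, and when there are four axes (the branch commented "# Flux 2") the fourth text column is filled with the token index 0..T-1, giving text tokens their own rotary axis; image ids keep reference index, height, and width in the first three columns. A short sketch of just that id construction, with an assumed 4-entry axes_dim for illustration:

import torch

axes_dim = [32, 32, 32, 32]      # assumed 4-axis layout, for illustration only
bs, txt_len = 2, 77
steps_h, steps_w = 16, 16

# Image ids: column 0 is the reference index, columns 1/2 are height/width positions.
img_ids = torch.zeros((steps_h, steps_w, len(axes_dim)), dtype=torch.float32)
img_ids[:, :, 1] = torch.linspace(0, steps_h - 1, steps=steps_h).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, steps_w - 1, steps=steps_w).unsqueeze(0)

# Text ids: only the 4th column is populated, with the per-token position.
txt_ids = torch.zeros((bs, txt_len, len(axes_dim)), dtype=torch.float32)
if len(axes_dim) == 4:
    txt_ids[:, :, 3] = torch.linspace(0, txt_len - 1, steps=txt_len)

print(img_ids.shape, txt_ids.shape)  # torch.Size([16, 16, 4]) torch.Size([2, 77, 4])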