Skip to content

Commit 31831e6

Browse files
Code refactor.
1 parent 88ceb28 commit 31831e6

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

comfy/ldm/flux/layers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,7 @@ def __init__(
230230

231231
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
232232
mod, _ = self.modulation(vec)
233-
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
234-
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
233+
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
235234

236235
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
237236
q, k = self.norm(q, k, v)

0 commit comments

Comments
 (0)