
Commit 94c298f

flux: reduce VRAM usage (#10737)

Clean up a bunch of tensors kept alive on the stack in Flux. This takes me from B=19 to B=22 for 1600x1600 on an RTX 5090.

1 parent 2fde959
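
The whole change is one pattern: drop the last Python reference to a large intermediate tensor so PyTorch's caching allocator can reuse its block before the next big allocation happens. A minimal standalone sketch of why that lowers peak VRAM; the shapes and the `measure` helper are illustrative, not part of the commit:

```python
import torch

def measure(fn, device="cuda"):
    """Run fn and report peak VRAM allocated during the call, in MiB."""
    torch.cuda.reset_peak_memory_stats(device)
    fn()
    return torch.cuda.max_memory_allocated(device) / 2**20

def without_del():
    x = torch.randn(4096, 4096, device="cuda")  # 64 MiB
    y = x * 2                                    # +64 MiB
    return y + 1                                 # +64 MiB: x, y, result coexist

def with_del():
    x = torch.randn(4096, 4096, device="cuda")
    y = x * 2
    del x          # x's block goes back to the caching allocator...
    return y + 1   # ...and is reused here, so the peak stays lower

if torch.cuda.is_available():
    print(f"peak without del: {measure(without_del):.0f} MiB")  # ~192
    print(f"peak with del:    {measure(with_del):.0f} MiB")     # ~128
```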

File tree

1 file changed (+26 −7)

comfy/ldm/flux/layers.py

Lines changed: 26 additions & 7 deletions
```diff
@@ -167,39 +167,55 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
+        del img_modulated
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         if self.flipped_img_txt:
+            q = torch.cat((img_q, txt_q), dim=2)
+            del img_q, txt_q
+            k = torch.cat((img_k, txt_k), dim=2)
+            del img_k, txt_k
+            v = torch.cat((img_v, txt_v), dim=2)
+            del img_v, txt_v
             # run actual attention
-            attn = attention(torch.cat((img_q, txt_q), dim=2),
-                             torch.cat((img_k, txt_k), dim=2),
-                             torch.cat((img_v, txt_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
         else:
+            q = torch.cat((txt_q, img_q), dim=2)
+            del txt_q, img_q
+            k = torch.cat((txt_k, img_k), dim=2)
+            del txt_k, img_k
+            v = torch.cat((txt_v, img_v), dim=2)
+            del txt_v, img_v
             # run actual attention
-            attn = attention(torch.cat((txt_q, img_q), dim=2),
-                             torch.cat((txt_k, img_k), dim=2),
-                             torch.cat((txt_v, img_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        del img_attn
         img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
         txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        del txt_attn
         txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:
```
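
The ordering in this hunk is the point: each fused tensor is materialized one at a time and its two per-stream halves are dropped before the next `torch.cat` allocates, and `q`, `k`, `v` are deleted as soon as `attention` returns. A hedged standalone sketch of that ordering, with illustrative `(batch, heads, seq, head_dim)` shapes:

```python
import torch

img_q, img_k, img_v = (torch.randn(1, 24, 4096, 128) for _ in range(3))
txt_q, txt_k, txt_v = (torch.randn(1, 24, 512, 128) for _ in range(3))

# Before: the three cats were nested inside the attention(...) call, so all
# six halves plus all three fused temporaries were live at once. After: each
# pair of halves is freed as soon as its fused tensor exists.
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
```

One caveat worth knowing: `img_q`, `img_k`, `img_v` come out of `.view(...).permute(...)`, so they are views sharing `img_qkv`'s storage, and `del img_qkv` only drops a name; that block is actually released once the cats have copied the data out and the last of the three views is deleted.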
```diff
@@ -249,12 +265,15 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del qkv
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        mlp = self.mlp_act(mlp)
+        output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
```
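
The two-step rewrite at the end of this hunk looks cosmetic but is not: rebinding `mlp` to the activation output releases the pre-activation tensor before `torch.cat` and `linear2` allocate their results, instead of holding it until `forward` returns. A small sketch; GELU and the shape are assumptions here, the commit itself just calls `self.mlp_act`:

```python
import torch
import torch.nn.functional as F

mlp = torch.randn(1, 4096, 12288)  # illustrative pre-activation tensor

# Nested form: F.gelu(mlp) is a temporary inside the cat call, so the
# pre-activation stays bound to `mlp` while the cat and matmul run.
# Rebound form: the pre-activation block is freed the moment the
# activation output takes over the name.
mlp = F.gelu(mlp)
```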
