Skip to content

Commit 4965c0e

Browse files
authored
WAN: Fix cache VRAM leak on error (#10141)
If this suffers an exception (such as a VRAM OOM), it will leave the encode() and decode() methods without running the cleanup of the WAN feature cache. The ComfyUI node cache then ultimately keeps a reference to this object, which in turn holds references to large tensors from the failed execution. The feature cache is currently set up as a class variable on the encoder/decoder; however, the encode and decode functions always clear it on both entry and exit during normal execution. It's likely the design intent was for this to be usable as a streaming encoder where the input arrives in batches, but the functions as they stand today don't support that. So simplify by making the cache a local variable again, so that if a VRAM OOM does occur the cache is properly garbage collected when the encode()/decode() frames disappear from the stack.
1 parent 911331c commit 4965c0e

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

comfy/ldm/wan/vae.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -468,55 +468,46 @@ def __init__(self,
468468
attn_scales, self.temperal_upsample, dropout)
469469

470470
def encode(self, x):
471-
self.clear_cache()
471+
conv_idx = [0]
472+
feat_map = [None] * count_conv3d(self.decoder)
472473
## cache
473474
t = x.shape[2]
474475
iter_ = 1 + (t - 1) // 4
475476
## 对encode输入的x,按时间拆分为1、4、4、4....
476477
for i in range(iter_):
477-
self._enc_conv_idx = [0]
478+
conv_idx = [0]
478479
if i == 0:
479480
out = self.encoder(
480481
x[:, :, :1, :, :],
481-
feat_cache=self._enc_feat_map,
482-
feat_idx=self._enc_conv_idx)
482+
feat_cache=feat_map,
483+
feat_idx=conv_idx)
483484
else:
484485
out_ = self.encoder(
485486
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
486-
feat_cache=self._enc_feat_map,
487-
feat_idx=self._enc_conv_idx)
487+
feat_cache=feat_map,
488+
feat_idx=conv_idx)
488489
out = torch.cat([out, out_], 2)
489490
mu, log_var = self.conv1(out).chunk(2, dim=1)
490-
self.clear_cache()
491491
return mu
492492

493493
def decode(self, z):
494-
self.clear_cache()
494+
conv_idx = [0]
495+
feat_map = [None] * count_conv3d(self.decoder)
495496
# z: [b,c,t,h,w]
496497

497498
iter_ = z.shape[2]
498499
x = self.conv2(z)
499500
for i in range(iter_):
500-
self._conv_idx = [0]
501+
conv_idx = [0]
501502
if i == 0:
502503
out = self.decoder(
503504
x[:, :, i:i + 1, :, :],
504-
feat_cache=self._feat_map,
505-
feat_idx=self._conv_idx)
505+
feat_cache=feat_map,
506+
feat_idx=conv_idx)
506507
else:
507508
out_ = self.decoder(
508509
x[:, :, i:i + 1, :, :],
509-
feat_cache=self._feat_map,
510-
feat_idx=self._conv_idx)
510+
feat_cache=feat_map,
511+
feat_idx=conv_idx)
511512
out = torch.cat([out, out_], 2)
512-
self.clear_cache()
513513
return out
514-
515-
def clear_cache(self):
516-
self._conv_num = count_conv3d(self.decoder)
517-
self._conv_idx = [0]
518-
self._feat_map = [None] * self._conv_num
519-
#cache encode
520-
self._enc_conv_num = count_conv3d(self.encoder)
521-
self._enc_conv_idx = [0]
522-
self._enc_feat_map = [None] * self._enc_conv_num

0 commit comments

Comments
 (0)