Skip to content

Commit cca96a8

Browse files
Fix cosmos VAE failing with videos longer than 121 frames.
1 parent 619b8cd commit cca96a8

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

comfy/ldm/cosmos/vae.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919
from torch import nn
2020
from enum import Enum
21+
import math
2122

2223
from .cosmos_tokenizer.layers3d import (
2324
EncoderFactorized,
@@ -105,17 +106,23 @@ def encode(self, x):
105106
z, posteriors = self.distribution(moments)
106107
latent_ch = z.shape[1]
107108
latent_t = z.shape[2]
108-
dtype = z.dtype
109-
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
110-
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
109+
in_dtype = z.dtype
110+
mean = self.latent_mean.view(latent_ch, -1)
111+
std = self.latent_std.view(latent_ch, -1)
112+
113+
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
114+
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
111115
return ((z - mean) / std) * self.sigma_data
112116

113117
def decode(self, z):
114118
in_dtype = z.dtype
115119
latent_ch = z.shape[1]
116120
latent_t = z.shape[2]
117-
mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
118-
std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
121+
mean = self.latent_mean.view(latent_ch, -1)
122+
std = self.latent_std.view(latent_ch, -1)
123+
124+
mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
125+
std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
119126

120127
z = z / self.sigma_data
121128
z = z * std + mean

0 commit comments

Comments
 (0)