|
18 | 18 | import torch |
19 | 19 | from torch import nn |
20 | 20 | from enum import Enum |
| 21 | +import math |
21 | 22 |
|
22 | 23 | from .cosmos_tokenizer.layers3d import ( |
23 | 24 | EncoderFactorized, |
@@ -105,17 +106,23 @@ def encode(self, x): |
105 | 106 | z, posteriors = self.distribution(moments) |
106 | 107 | latent_ch = z.shape[1] |
107 | 108 | latent_t = z.shape[2] |
108 | | - dtype = z.dtype |
109 | | - mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) |
110 | | - std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) |
| 109 | + in_dtype = z.dtype |
| 110 | + mean = self.latent_mean.view(latent_ch, -1) |
| 111 | + std = self.latent_std.view(latent_ch, -1) |
| 112 | + |
| 113 | + mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) |
| 114 | + std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) |
111 | 115 | return ((z - mean) / std) * self.sigma_data |
112 | 116 |
|
113 | 117 | def decode(self, z): |
114 | 118 | in_dtype = z.dtype |
115 | 119 | latent_ch = z.shape[1] |
116 | 120 | latent_t = z.shape[2] |
117 | | - mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) |
118 | | - std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) |
| 121 | + mean = self.latent_mean.view(latent_ch, -1) |
| 122 | + std = self.latent_std.view(latent_ch, -1) |
| 123 | + |
| 124 | + mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) |
| 125 | + std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) |
119 | 126 |
|
120 | 127 | z = z / self.sigma_data |
121 | 128 | z = z * std + mean |
|
0 commit comments