Skip to content

Commit f54576c

Browse files
committed
make proxy unoptional
1 parent 76a8b86 commit f54576c

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

owl_vaes/models/dito.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,12 @@ def __init__(self, config):
9494
config.normalize_mu = True
9595
self.encoder = Encoder(config)
9696

97+
proxy_channels = getattr(config, "proxy_channels", config.channels)
98+
proxy_sample_size = getattr(config, "proxy_sample_size", config.sample_size)
99+
97100
decoder_config = deepcopy(config)
98-
decoder_config.channels = config.proxy_channels + config.latent_channels
99-
decoder_config.sample_size = config.proxy_sample_size
101+
decoder_config.channels = proxy_channels + config.latent_channels
102+
decoder_config.sample_size = proxy_sample_size
100103

101104
self.decoder = DiTODecoder(decoder_config)
102105

@@ -158,7 +161,7 @@ def forward(self, x, proxy = None):
158161
pred = (pred - noisy_x) / den
159162
else:
160163
noisy_z = z * (1. - tau_exp) + tau_exp * eps_z
161-
pred = self.decoder(noisy_z, z, ts)
164+
pred = self.decoder(noisy_x, noisy_z, ts)
162165
loss = F.mse_loss(pred, target)
163166

164167
return loss, z_original

0 commit comments

Comments
 (0)