Skip to content

Commit 1986201

Browse files
committed
Allow customizing the training objective to predict either x0 or noise; default to predicting epsilon (noise), given results from @Lamikins
1 parent 49e52c0 commit 1986201

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 38 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -493,6 +493,7 @@ def __init__(
493493
timesteps = 1000,
494494
use_ddim = False,
495495
noise_schedule = 'sigmoid',
496+
objective = 'eps',
496497
schedule_kwargs: dict = dict(),
497498
time_difference = 0.,
498499
train_prob_self_cond = 0.9
@@ -501,6 +502,9 @@ def __init__(
501502
self.model = model
502503
self.channels = self.model.channels
503504

505+
assert objective in {'x0', 'eps'}, 'objective must be either predict x0 or noise'
506+
self.objective = objective
507+
504508
self.image_size = image_size
505509

506510
if noise_schedule == "linear":
@@ -560,11 +564,7 @@ def ddpm_sample(self, shape, time_difference = None):
560564

561565
# get predicted x0
562566

563-
x_start, last_latents = self.model(img, noise_cond, x_start, last_latents, return_latents = True)
564-
565-
# clip x0
566-
567-
x_start.clamp_(-1., 1.)
567+
model_output, last_latents = self.model(img, noise_cond, x_start, last_latents, return_latents = True)
568568

569569
# get log(snr)
570570

@@ -577,6 +577,18 @@ def ddpm_sample(self, shape, time_difference = None):
577577
alpha, sigma = log_snr_to_alpha_sigma(log_snr)
578578
alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
579579

580+
# calculate x0 and noise
581+
582+
if self.objective == 'x0':
583+
x_start = model_output
584+
585+
elif self.objective == 'eps':
586+
x_start = (img - sigma * model_output) / alpha
587+
588+
# clip x0
589+
590+
x_start.clamp_(-1., 1.)
591+
580592
# derive posterior mean and variance
581593

582594
c = -expm1(log_snr - log_snr_next)
@@ -628,15 +640,27 @@ def ddim_sample(self, shape, time_difference = None):
628640

629641
# predict x0
630642

631-
x_start, last_latents = self.model(img, log_snr, x_start, last_latents, return_latents = True)
643+
model_output, last_latents = self.model(img, log_snr, x_start, last_latents, return_latents = True)
644+
645+
# calculate x0 and noise
646+
647+
if self.objective == 'x0':
648+
x_start = model_output
649+
650+
elif self.objective == 'eps':
651+
x_start = (img - sigma * model_output) / alpha
632652

633653
# clip x0
634654

635655
x_start.clamp_(-1., 1.)
636656

637657
# get predicted noise
638658

639-
pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)
659+
if self.objective == 'x0':
660+
pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)
661+
662+
elif self.objective == 'eps':
663+
pred_noise = model_output
640664

641665
# calculate x next
642666

@@ -687,7 +711,13 @@ def forward(self, img, *args, **kwargs):
687711

688712
pred = self.model(noised_img, noise_level, self_cond, self_latents)
689713

690-
return F.mse_loss(pred, img)
714+
if self.objective == 'x0':
715+
target = img
716+
717+
elif self.objective == 'eps':
718+
target = noise
719+
720+
return F.mse_loss(pred, target)
691721

692722
# dataset classes
693723

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.1',
6+
version = '0.2.0',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)