Commit 34b430d

some changes needed to use gaussian diffusion 1d for large behavioral models
1 parent 1d9d8df commit 34b430d

File tree

2 files changed: +24 −9 lines

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 23 additions & 8 deletions
@@ -423,12 +423,18 @@ def __init__(
         objective = 'pred_noise',
         beta_schedule = 'cosine',
         ddim_sampling_eta = 0.,
-        auto_normalize = True
+        auto_normalize = True,
+        channels = None,
+        self_condition = None,
+        channel_first = True
     ):
         super().__init__()
         self.model = model
-        self.channels = self.model.channels
-        self.self_condition = self.model.self_condition
+        self.channels = default(channels, lambda: self.model.channels)
+        self.self_condition = default(self_condition, lambda: self.model.self_condition)
+
+        self.channel_first = channel_first
+        self.seq_index = -2 if not channel_first else -1
 
         self.seq_length = seq_length
 
@@ -638,7 +644,9 @@ def ddim_sample(self, shape, clip_denoised = True):
     def sample(self, batch_size = 16):
         seq_length, channels = self.seq_length, self.channels
         sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
-        return sample_fn((batch_size, channels, seq_length))
+
+        shape = (batch_size, channels, seq_length) if self.channel_first else (batch_size, seq_length, channels)
+        return sample_fn(shape)
 
     @torch.no_grad()
     def interpolate(self, x1, x2, t = None, lam = 0.5):
@@ -669,8 +677,10 @@ def q_sample(self, x_start, t, noise=None):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
-    def p_losses(self, x_start, t, noise = None):
-        b, c, n = x_start.shape
+    def p_losses(self, x_start, t, noise = None, model_forward_kwargs: dict = dict()):
+        b = x_start.shape[0]
+        n = x_start.shape[self.seq_index]
+
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # noise sample
@@ -687,9 +697,13 @@ def p_losses(self, x_start, t, noise = None):
                 x_self_cond = self.model_predictions(x, t).pred_x_start
                 x_self_cond.detach_()
 
+        model_forward_kwargs = {**model_forward_kwargs, 'self_cond': x_self_cond}
+
+        # model kwargs
+
         # predict and take gradient step
 
-        model_out = self.model(x, t, x_self_cond)
+        model_out = self.model(x, t, **model_forward_kwargs)
 
         if self.objective == 'pred_noise':
             target = noise
@@ -708,7 +722,8 @@ def p_losses(self, x_start, t, noise = None):
         return loss.mean()
 
     def forward(self, img, *args, **kwargs):
-        b, c, n, device, seq_length, = *img.shape, img.device, self.seq_length
+        b, n, device, seq_length, = img.shape[0], img.shape[self.seq_index], img.device, self.seq_length
+
         assert n == seq_length, f'seq length must be {seq_length}'
         t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '2.1.1'
+__version__ = '2.2.0'
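
Taken together, the new arguments let GaussianDiffusion1D wrap backbones that operate on channel-last sequences and that do not expose channels or self_condition attributes; p_losses also gains a model_forward_kwargs dict, into which 'self_cond' is always merged before the model is called. Below is a minimal usage sketch against the patched class. The BehaviorDenoiser module, its sizes, and the random training batch are illustrative assumptions, not part of this commit; only the GaussianDiffusion1D arguments shown in the diff above come from it.

import torch
from torch import nn
from denoising_diffusion_pytorch import GaussianDiffusion1D

# Hypothetical channel-last behavioral backbone (illustrative, not from this repo).
# Its forward must accept the `self_cond` keyword, since the patched p_losses calls
# self.model(x, t, **model_forward_kwargs) with 'self_cond' merged in.
class BehaviorDenoiser(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x, t, self_cond = None, **kwargs):
        # x has shape (batch, seq, dim) because channel_first = False below
        return self.net(x)

model = BehaviorDenoiser(dim = 32)

diffusion = GaussianDiffusion1D(
    model,
    seq_length = 128,
    timesteps = 1000,
    channels = 32,           # needed here: the model has no .channels attribute
    self_condition = False,  # needed here: the model has no .self_condition attribute
    channel_first = False    # batches are (batch, seq, channels) rather than (batch, channels, seq)
)

training_seqs = torch.randn(8, 128, 32)      # channel-last training batch
loss = diffusion(training_seqs)
loss.backward()

sampled = diffusion.sample(batch_size = 4)   # shape (4, 128, 32)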
