Skip to content

Commit b82c3a1

Browse files
committed
Add diffusion-steering support for the large behavioral model
1 parent d2ab4ac commit b82c3a1

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -574,8 +574,12 @@ def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rede
574574

575575
return ModelPrediction(pred_noise, x_start)
576576

577-
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
578-
preds = self.model_predictions(x, t, x_self_cond)
577+
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True, model_forward_kwargs: dict = dict()):
578+
579+
if exists(x_self_cond):
580+
model_forward_kwargs = {**model_forward_kwargs, 'self_cond': x_self_cond}
581+
582+
preds = self.model_predictions(x, t, **model_forward_kwargs)
579583
x_start = preds.pred_x_start
580584

581585
if clip_denoised:
@@ -585,38 +589,44 @@ def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
585589
return model_mean, posterior_variance, posterior_log_variance, x_start
586590

587591
@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True, model_forward_kwargs = None):
    """Single reverse-diffusion (ancestral) step: draw x_{t-1} from x_t.

    Args:
        x: current noisy sample x_t (batched tensor).
        t: scalar python int timestep, broadcast to the whole batch.
        x_self_cond: optional self-conditioning input forwarded to the model.
        clip_denoised: whether p_mean_variance clamps the predicted x_start.
        model_forward_kwargs: extra kwargs forwarded to the model call
            (e.g. steering conditioning). ``None`` means "no extras" —
            replaces the original mutable default ``dict()`` (shared-object
            pitfall) while keeping behavior identical for all callers.

    Returns:
        (pred_img, x_start): the sampled x_{t-1} and the model's x_0 estimate.
    """
    # Normalize the default here so downstream always receives a dict.
    model_forward_kwargs = dict() if model_forward_kwargs is None else model_forward_kwargs
    b, *_, device = *x.shape, x.device
    batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
    model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised, model_forward_kwargs = model_forward_kwargs)
    noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
    # Reparameterized gaussian sample: mean + sigma * eps, sigma = exp(0.5 * log_var).
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start
595599

596600
@torch.no_grad()
def p_sample_loop(self, shape, return_noise = False, model_forward_kwargs = None):
    """Full ancestral sampling loop: iterate p_sample from t = T-1 down to 0.

    Args:
        shape: shape of the sample batch to generate.
        return_noise: when True, also return the initial gaussian noise the
            trajectory started from (useful for steering / inversion work).
        model_forward_kwargs: extra kwargs forwarded to every model call.
            ``None`` means "no extras" — replaces the original mutable
            default ``dict()`` (shared-object pitfall), same behavior.

    Returns:
        The unnormalized sample ``img``, or ``(img, noise)`` when
        ``return_noise`` is True.
    """
    model_forward_kwargs = dict() if model_forward_kwargs is None else model_forward_kwargs
    # NOTE: the original also unpacked an unused `batch = shape[0]` — dropped.
    device = self.betas.device

    # Keep a handle on the starting noise so it can be returned to the caller.
    noise = torch.randn(shape, device = device)
    img = noise

    x_start = None

    for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
        # Self-conditioning feeds the previous x_0 estimate back into the model.
        self_cond = x_start if self.self_condition else None
        img, x_start = self.p_sample(img, t, self_cond, model_forward_kwargs = model_forward_kwargs)

    img = self.unnormalize(img)

    if not return_noise:
        return img

    return img, noise
610619

611620
@torch.no_grad()
612-
def ddim_sample(self, shape, clip_denoised = True, model_forward_kwargs: dict = dict()):
621+
def ddim_sample(self, shape, clip_denoised = True, model_forward_kwargs: dict = dict(), return_noise = False):
613622
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
614623

615624
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
616625
times = list(reversed(times.int().tolist()))
617626
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
618627

619-
img = torch.randn(shape, device = device)
628+
noise = torch.randn(shape, device = device)
629+
img = noise
620630

621631
x_start = None
622632

@@ -642,15 +652,19 @@ def ddim_sample(self, shape, clip_denoised = True, model_forward_kwargs: dict =
642652
sigma * noise
643653

644654
img = self.unnormalize(img)
645-
return img
655+
656+
if not return_noise:
657+
return img
658+
659+
return img, noise
646660

647661
@torch.no_grad()
def sample(self, batch_size = 16, return_noise = False, model_forward_kwargs = None):
    """Generate a batch of samples via the configured sampler.

    Dispatches to ``ddim_sample`` when ``self.is_ddim_sampling`` is set,
    otherwise to the full ancestral ``p_sample_loop``.

    Args:
        batch_size: number of sequences to generate.
        return_noise: also return the initial noise (forwarded to the sampler).
        model_forward_kwargs: extra kwargs forwarded to the model on every
            step. ``None`` means "no extras" — replaces the original mutable
            default ``dict()`` (shared-object pitfall), same behavior.

    Returns:
        Whatever the chosen sampler returns: the sample tensor, or
        ``(sample, noise)`` when ``return_noise`` is True.
    """
    model_forward_kwargs = dict() if model_forward_kwargs is None else model_forward_kwargs
    seq_length, channels = self.seq_length, self.channels
    sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
    # Layout depends on whether the model is channel-first (B, C, L) or (B, L, C).
    shape = (batch_size, channels, seq_length) if self.channel_first else (batch_size, seq_length, channels)
    return sample_fn(shape, return_noise = return_noise, model_forward_kwargs = model_forward_kwargs)
654668

655669
@torch.no_grad()
656670
def interpolate(self, x1, x2, t = None, lam = 0.5):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
# Package version — bumped 2.2.1 -> 2.2.2 for the diffusion-steering additions
# (presumably the `return_noise` / `model_forward_kwargs` plumbing in this commit).
__version__ = '2.2.2'

0 commit comments

Comments
 (0)