Skip to content

Commit d2c1566

Browse files
committed
needed for a scheme for better language following for robotics
1 parent b82c3a1 commit d2c1566

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def q_sample(self, x_start, t, noise=None):
695695
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
696696
)
697697

698-
def p_losses(self, x_start, t, noise = None, model_forward_kwargs: dict = dict()):
698+
def p_losses(self, x_start, t, noise = None, model_forward_kwargs: dict = dict(), return_reduced_loss = True):
699699
b = x_start.shape[0]
700700
n = x_start.shape[self.seq_index]
701701

@@ -737,7 +737,11 @@ def p_losses(self, x_start, t, noise = None, model_forward_kwargs: dict = dict()
737737
loss = reduce(loss, 'b ... -> b', 'mean')
738738

739739
loss = loss * extract(self.loss_weight, t, loss.shape)
740-
return loss.mean()
740+
741+
if return_reduced_loss:
742+
loss = loss.mean()
743+
744+
return loss
741745

742746
def forward(self, img, *args, **kwargs):
743747
b, n, device, seq_length, = img.shape[0], img.shape[self.seq_index], img.device, self.seq_length
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.2.2'
1+
__version__ = '2.2.3'

0 commit comments

Comments
 (0)