Skip to content

Commit f351148

Browse files
committed
adjustment to returning unreduced loss in gaussian 1d
1 parent d2c1566 commit f351148

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -734,14 +734,15 @@ def p_losses(self, x_start, t, noise = None, model_forward_kwargs: dict = dict()
734734
raise ValueError(f'unknown objective {self.objective}')
735735

736736
loss = F.mse_loss(model_out, target, reduction = 'none')
737+
738+
if not return_reduced_loss:
739+
return loss * extract(self.loss_weight, t, loss.shape)
740+
737741
loss = reduce(loss, 'b ... -> b', 'mean')
738742

739743
loss = loss * extract(self.loss_weight, t, loss.shape)
740744

741-
if return_reduced_loss:
742-
loss = loss.mean()
743-
744-
return loss
745+
return loss.mean()
745746

746747
def forward(self, img, *args, **kwargs):
747748
b, n, device, seq_length, = img.shape[0], img.shape[self.seq_index], img.device, self.seq_length
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.2.3'
1+
__version__ = '2.2.4'

0 commit comments

Comments
 (0)