Skip to content

Commit 8786417

Browse files
authored
Merge pull request #2 from YdaodiG/loss_corrected
fix bug in loss function
2 parents a892ede + ee0517b commit 8786417

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/DOSE/learner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def train_step(self, features):
138138

139139
audio = features['clean_speech']
140140
noisy = features['noisy_speech']
141-
141+
audio_orig = features['clean_speech'].clone()
142142

143143
N,T= audio.shape
144144
device = audio.device
@@ -159,7 +159,7 @@ def train_step(self, features):
159159

160160
noisy_audio = noise_scale_sqrt * audio + (1.0 - noise_scale)**0.5 * noise
161161
predicted = self.model(noisy_audio, t, noisy)
162-
loss = self.loss_fn(audio, predicted.squeeze(1))
162+
loss = self.loss_fn(audio_orig, predicted.squeeze(1))
163163

164164

165165
self.scaler.scale(loss).backward()

0 commit comments

Comments
 (0)