@@ -29,26 +29,21 @@ def discriminator_loss(self, disc_real_outputs, disc_generated_outputs):
2929 loss = 0
3030 rlosses = 0
3131 glosses = 0
32- r_losses = []
33- g_losses = []
3432
3533 for dr , dg in zip (disc_real_outputs , disc_generated_outputs ):
3634 r_loss = torch .mean ((1 - dr ) ** 2 )
3735 g_loss = torch .mean (dg ** 2 )
3836 loss += r_loss + g_loss
3937 rlosses += r_loss .item ()
4038 glosses += g_loss .item ()
41- r_losses .append (r_loss .item ())
42- g_losses .append (g_loss .item ())
4339
44- return loss , rlosses , glosses , r_losses , g_losses
40+ return loss , rlosses , glosses
4541
def Dloss(self, Dfake, Dtrue):
    """Combine the discriminator losses of the two discriminator groups.

    Args:
        Dfake: Two ``(outputs, extra)`` pairs computed on generated audio;
            the first pair is used as the "msd" group and the second as the
            "mpd" group. The ``extra`` element of each pair is ignored here.
        Dtrue: Same structure as ``Dfake``, computed on real audio.

    Returns:
        A ``(loss, logs)`` tuple. ``loss`` is the sum of the two totals
        returned by ``self.discriminator_loss``. ``logs`` maps
        ``'DmsdlossF'``/``'DmpdlossF'`` to the generated-side values and
        ``'DmsdlossT'``/``'DmpdlossT'`` to the real-side values.
    """
    (Fmsd_out, _), (Fmpd_out, _) = Dfake
    (Tmsd_out, _), (Tmpd_out, _) = Dtrue

    # Only the first three returned values are used. The starred target keeps
    # this method working whether ``discriminator_loss`` returns exactly three
    # values or additionally appends per-discriminator lists.
    msdloss, msdrlosses, msdglosses, *_ = self.discriminator_loss(Tmsd_out, Fmsd_out)
    mpdloss, mpdrlosses, mpdglosses, *_ = self.discriminator_loss(Tmpd_out, Fmpd_out)

    loss = msdloss + mpdloss
    return loss, {
        'DmsdlossF': msdglosses,
        'DmsdlossT': msdrlosses,
        'DmpdlossT': mpdrlosses,
        'DmpdlossF': mpdglosses,
    }
@@ -57,55 +52,42 @@ def feature_loss(self, fmap_r, fmap_g):
5752 loss = 0
5853 for dr , dg in zip (fmap_r , fmap_g ):
5954 for rl , gl in zip (dr , dg ):
60- loss += torch . mean ( torch . abs ( rl - gl ) )
61-
55+ b = min ( rl . shape [ 0 ], gl . shape [ 0 ] )
56+ loss += torch . mean ( torch . abs ( rl [: b ] - gl [: b ]))
6257 return loss * 2
6358
def GDloss(self, GDfake, GDtrue):
    """Compute the generator's adversarial and feature losses.

    Args:
        GDfake: Two ``(outputs, features)`` pairs obtained from generated
            audio; the first pair is treated as the "msd" group and the
            second as the "mpd" group.
        GDtrue: Same structure obtained from real audio. Only the
            ``features`` element of each pair is read.

    Returns:
        A ``(loss, logs)`` tuple. ``loss`` is the sum of both feature losses
        and both adversarial losses. ``logs`` holds the four individual
        terms under ``'Gmsdloss'``, ``'Gmpdloss'``, ``'Gmsd_feature_loss'``
        and ``'Gmpd_feature_loss'``.
    """
    (msd_out, Fmsd_feature), (mpd_out, Fmpd_feature) = GDfake
    (_, Tmsd_feature), (_, Tmpd_feature) = GDtrue

    # Least-squares term: mean of (1 - output)^2 per discriminator output.
    # ``sum`` starts from the int 0, so an empty output list yields 0.
    msd_losses = sum(torch.mean((1 - dg) ** 2) for dg in msd_out)
    mpd_losses = sum(torch.mean((1 - dg) ** 2) for dg in mpd_out)

    msd_feature_loss = self.feature_loss(Tmsd_feature, Fmsd_feature)
    mpd_feature_loss = self.feature_loss(Tmpd_feature, Fmpd_feature)

    loss = msd_feature_loss + mpd_feature_loss + mpd_losses + msd_losses

    return loss, {
        'Gmsdloss': msd_losses,
        'Gmpdloss': mpd_losses,
        'Gmsd_feature_loss': msd_feature_loss,
        'Gmpd_feature_loss': mpd_feature_loss,
    }
9179
def Auxloss(self, Goutput, sample):
    """Compute the auxiliary mel loss and, optionally, the STFT loss.

    Args:
        Goutput: Mapping whose ``'audio'`` entry is the generated waveform.
        sample: Mapping whose ``'audio'`` entry is the reference waveform.

    Returns:
        ``(loss, logs)``. When ``self.use_stftloss`` is truthy, ``loss`` is
        the weighted mel loss plus the weighted STFT loss and ``logs`` has
        ``'aux_mel_loss'`` and ``'aux_stft_loss'``. Otherwise ``loss`` is the
        weighted mel loss alone and ``logs`` has only ``'aux_mel_loss'``.
    """
    Gwav = Goutput['audio'].squeeze(1)
    Rwav = sample['audio'].squeeze(1)

    # Trim both waveforms to the smaller leading dimension so the two sides
    # can be compared element-wise (presumably the batch axis -- confirm
    # against the caller).
    b = min(Gwav.shape[0], Rwav.shape[0])
    Gwav = Gwav[:b]
    Rwav = Rwav[:b]

    Gmel = self.mel.dynamic_range_compression_torch(self.mel(Gwav))
    Rmel = self.mel.dynamic_range_compression_torch(self.mel(Rwav))
    mel_loss = self.L1loss(Gmel, Rmel) * self.lab_aux_mel_loss

    if not self.use_stftloss:
        return mel_loss, {'aux_mel_loss': mel_loss}

    sc_loss, mag_loss = self.stft.stft(Gwav, Rwav)
    stft_loss = (sc_loss + mag_loss) * self.lab_aux_stft_loss
    loss = mel_loss + stft_loss
    return loss, {'aux_mel_loss': mel_loss, 'aux_stft_loss': stft_loss}
93+
0 commit comments