@@ -23,8 +23,6 @@ def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
23
23
else :
24
24
self .source_sample_rate = source_sample_rate
25
25
26
- # self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
27
-
28
26
self .transform = transforms .Compose ([
29
27
transforms .Normalize (0.5 , 0.5 ),
30
28
])
@@ -37,10 +35,6 @@ def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
37
35
self .scale_factor = 0.1786
38
36
self .shift_factor = - 1.9091
39
37
40
- def load_audio (self , audio_path ):
41
- audio , sr = torchaudio .load (audio_path )
42
- return audio , sr
43
-
44
38
def forward_mel (self , audios ):
45
39
mels = []
46
40
for i in range (len (audios )):
@@ -73,10 +67,8 @@ def encode(self, audios, audio_lengths=None, sr=None):
73
67
latent = self .dcae .encoder (mel .unsqueeze (0 ))
74
68
latents .append (latent )
75
69
latents = torch .cat (latents , dim = 0 )
76
- # latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
77
70
latents = (latents - self .shift_factor ) * self .scale_factor
78
71
return latents
79
- # return latents, latent_lengths
80
72
81
73
@torch .no_grad ()
82
74
def decode (self , latents , audio_lengths = None , sr = None ):
@@ -91,17 +83,14 @@ def decode(self, latents, audio_lengths=None, sr=None):
91
83
wav = self .vocoder .decode (mels [0 ]).squeeze (1 )
92
84
93
85
if sr is not None :
94
- # resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
95
86
wav = torchaudio .functional .resample (wav , 44100 , sr )
96
- # wav = resampler(wav)
97
87
else :
98
88
sr = 44100
99
89
pred_wavs .append (wav )
100
90
101
91
if audio_lengths is not None :
102
92
pred_wavs = [wav [:, :length ].cpu () for wav , length in zip (pred_wavs , audio_lengths )]
103
93
return torch .stack (pred_wavs )
104
- # return sr, pred_wavs
105
94
106
95
def forward (self , audios , audio_lengths = None , sr = None ):
107
96
latents , latent_lengths = self .encode (audios = audios , audio_lengths = audio_lengths , sr = sr )
0 commit comments