@@ -49,6 +49,95 @@ def bark_scale(
4949 return res
5050
5151
52+ # copied code from
53+ # https://github.com/magenta/magenta/blob/main/magenta/models/gansynth/lib/spectral_ops.py
54+ _MEL_BREAK_FREQUENCY_HERTZ = 700.0
55+ _MEL_HIGH_FREQUENCY_Q = 1127.0
56+
57+
58+ def mel_to_hertz (mel_values : th .Tensor ) -> th .Tensor :
59+ return _MEL_BREAK_FREQUENCY_HERTZ * (
60+ th .exp (mel_values / _MEL_HIGH_FREQUENCY_Q ) - 1.0
61+ )
62+
63+
64+ def hertz_to_mel (frequencies_hertz : th .Tensor ) -> th .Tensor :
65+ return _MEL_HIGH_FREQUENCY_Q * th .log (
66+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ )
67+ )
68+
69+
70+ def linear_to_mel_weight_matrix (
71+ num_mel_bins : int = constants .N_FFT // 2 ,
72+ num_spectrogram_bins : int = constants .N_FFT // 2 ,
73+ sample_rate : int = constants .SAMPLE_RATE ,
74+ lower_edge_hertz : float = 125.0 ,
75+ upper_edge_hertz : float = 3800.0 ,
76+ ) -> th .Tensor :
77+
78+ # HTK excludes the spectrogram DC bin.
79+ bands_to_zero = 1
80+ nyquist_hertz = sample_rate / 2.0
81+ linear_frequencies = th .linspace (0.0 , nyquist_hertz , num_spectrogram_bins )[
82+ bands_to_zero :, None
83+ ]
84+ # spectrogram_bins_mel = hertz_to_mel(linear_frequencies)
85+
86+ # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
87+ # center of each band is the lower and upper edge of the adjacent bands.
88+ # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
89+ # num_mel_bins + 2 pieces.
90+ band_edges_mel = th .linspace (
91+ hertz_to_mel (th .tensor (lower_edge_hertz )).item (),
92+ hertz_to_mel (th .tensor (upper_edge_hertz )).item (),
93+ num_mel_bins + 2 ,
94+ )
95+
96+ lower_edge_mel = band_edges_mel [0 :- 2 ]
97+ center_mel = band_edges_mel [1 :- 1 ]
98+ upper_edge_mel = band_edges_mel [2 :]
99+
100+ freq_res = nyquist_hertz / float (num_spectrogram_bins )
101+ freq_th = 1.5 * freq_res
102+ for i in range (0 , num_mel_bins ):
103+ center_hz = mel_to_hertz (center_mel [i ])
104+ lower_hz = mel_to_hertz (lower_edge_mel [i ])
105+ upper_hz = mel_to_hertz (upper_edge_mel [i ])
106+ if upper_hz - lower_hz < freq_th :
107+ rhs = 0.5 * freq_th / (center_hz + _MEL_BREAK_FREQUENCY_HERTZ )
108+ dm = _MEL_HIGH_FREQUENCY_Q * th .log (rhs + th .sqrt (1.0 + rhs ** 2 ))
109+ lower_edge_mel [i ] = center_mel [i ] - dm
110+ upper_edge_mel [i ] = center_mel [i ] + dm
111+
112+ lower_edge_hz = mel_to_hertz (lower_edge_mel )[None , :]
113+ center_hz = mel_to_hertz (center_mel )[None , :]
114+ upper_edge_hz = mel_to_hertz (upper_edge_mel )[None , :]
115+
116+ # Calculate lower and upper slopes for every spectrogram bin.
117+ # Line segments are linear in the mel domain, not Hertz.
118+ lower_slopes = (linear_frequencies - lower_edge_hz ) / (
119+ center_hz - lower_edge_hz
120+ )
121+ upper_slopes = (upper_edge_hz - linear_frequencies ) / (
122+ upper_edge_hz - center_hz
123+ )
124+
125+ # Intersect the line segments with each other and zero.
126+ mel_weights_matrix = th .maximum (
127+ th .tensor (0.0 ), th .minimum (lower_slopes , upper_slopes )
128+ )
129+
130+ # Re-add the zeroed lower bins we sliced out above.
131+ # [freq, mel]
132+ mel_weights_matrix = th_f .pad (
133+ mel_weights_matrix , [bands_to_zero , 0 , 0 , 0 ], "constant"
134+ )
135+ return mel_weights_matrix
136+
137+
138+ # end of copied code
139+
140+
52141def wav_to_stft (
53142 wav_p : str ,
54143 n_per_seg : int = constants .N_FFT ,
@@ -130,6 +219,8 @@ def magnitude_phase_to_wav(
130219 sample_rate : int ,
131220 n_fft : int = constants .N_FFT ,
132221 stft_stride : int = constants .STFT_STRIDE ,
222+ threshold : float = 1.0 / 2 ** 8 ,
223+ magn_scale : float = 1.0 ,
133224) -> None :
134225 assert (
135226 len (magnitude_phase .size ()) == 4
@@ -151,7 +242,9 @@ def magnitude_phase_to_wav(
151242 phase = magnitude_phase_flattened [1 , :, :]
152243
153244 magnitude = (magnitude + 1.0 ) / 2.0
245+ magnitude [magnitude < threshold ] = 0.0
154246 magnitude = bark_scale (magnitude , "unscale" )
247+ magnitude = magnitude * magn_scale
155248
156249 phase = (phase + 1.0 ) / 2.0 * 2.0 * th .pi - th .pi
157250 phase = simpson (th .zeros (phase .size ()[0 ], 1 ), phase , 1 , 1.0 )
@@ -191,34 +284,30 @@ def create_dataset(
191284 elif not isdir (dataset_output_dir ):
192285 raise NotADirectoryError (dataset_output_dir )
193286
194- n_per_seg = constants .N_FFT
195- stride = constants .STFT_STRIDE
196-
197- nb_vec = constants .N_VEC
198-
199287 idx = 0
200288
201289 for wav_p in tqdm (w_p ):
202- complex_values = wav_to_stft (wav_p , n_per_seg = n_per_seg , stride = stride )
290+ complex_values = wav_to_stft (
291+ wav_p , n_per_seg = constants .N_FFT , stride = constants .STFT_STRIDE
292+ )
203293
204- if complex_values .size ()[1 ] < nb_vec :
294+ if complex_values .size ()[1 ] < constants . N_VEC :
205295 continue
206296
207297 magnitude , phase = stft_to_magnitude_phase (
208- complex_values , nb_vec = nb_vec
298+ complex_values , nb_vec = constants . N_VEC
209299 )
210300
211301 nb_sample = magnitude .size ()[0 ]
212302
213303 for s_idx in range (nb_sample ):
214- s_magnitude = magnitude [s_idx , :, :]
215- s_phase = phase [s_idx , :, :]
216-
217304 magnitude_phase_path = join (
218305 dataset_output_dir , f"magn_phase_{ idx } .pt"
219306 )
220307
221- magnitude_phase = th .stack ([s_magnitude , s_phase ], dim = 0 )
308+ magnitude_phase = th .stack (
309+ [magnitude [s_idx , :, :], phase [s_idx , :, :]], dim = 0
310+ )
222311 magnitude_phase = magnitude_phase .to (th .float )
223312
224313 th .save (magnitude_phase , magnitude_phase_path )
0 commit comments