@@ -83,13 +83,23 @@ def __getitem__(self, index):
8383 if random .random () < self .key_aug_prob :
8484 audio = torch .from_numpy (data ['audio' ])
8585 speed = random .uniform (self .config ['aug_min' ], self .config ['aug_max' ])
86- audiox = wav_aug (audio , self .config ["hop_size" ], speed = speed )
87- mel = dynamic_range_compression_torch (self .mel_spec_transform (audiox [None ,:]))
88- f0 , uv = get_pitch (audio .numpy (), hparams = self .config , speed = speed , interp_uv = True , length = len (mel [0 ].T ))
86+ crop_mel_frames = int (np .ceil ((self .config ['crop_mel_frames' ] + 4 ) * speed ))
87+ samples_per_frame = self .config ['hop_size' ]
88+ crop_wav_samples = crop_mel_frames * samples_per_frame
89+ if crop_wav_samples < audio .shape [0 ]:
90+ return {'f0' : data ['f0' ], 'spectrogram' : data ['mel' ], 'audio' : data ['audio' ]}
91+ start = random .randint (0 , audio .shape [0 ] - 1 - crop_wav_samples )
92+ end = start + crop_wav_samples
93+ audio = audio [start :end ]
94+ f0 , uv = get_pitch (audio .numpy (), hparams = self .config , speed = speed , interp_uv = True , length = mel .shape [- 1 ])
8995 if f0 is None :
9096 return {'f0' : data ['f0' ], 'spectrogram' : data ['mel' ], 'audio' : data ['audio' ]}
91- f0 *= speed
92- return {'f0' : f0 , 'spectrogram' : mel [0 ].T .numpy (), 'audio' : audiox .numpy ()}
97+ audio_aug = wav_aug (audio , self .config ["hop_size" ], speed = speed )
98+ mel_aug = dynamic_range_compression_torch (self .mel_spec_transform (audio_aug [None ,:]))
99+ audio_aug = audio_aug [2 * samples_per_frame : - 2 * samples_per_frame ].numpy ()
100+ mel_aug = mel_aug [0 , :, 2 :- 2 ].T .numpy ()
101+ f0_aug = f0 [2 :- 2 ] * speed
102+ return {'f0' : f0_aug , 'spectrogram' : mel_aug , 'audio' : audio_aug }
93103
94104 else :
95105 return {'f0' : data ['f0' ], 'spectrogram' : data ['mel' ], 'audio' : data ['audio' ]}
@@ -107,13 +117,15 @@ def collater(self, minibatch):
107117 for record in minibatch :
108118
109119 # Filter out records that aren't long enough.
110- if len ( record ['spectrogram' ]) <= crop_mel_frames :
120+ if record ['spectrogram' ]. shape [ 0 ] < crop_mel_frames :
111121 del record ['spectrogram' ]
112122 del record ['audio' ]
113123 del record ['f0' ]
114124 continue
115-
116- start = random .randint (0 , record ['spectrogram' ].shape [0 ] - 1 - crop_mel_frames )
125+ elif record ['spectrogram' ].shape [0 ] == crop_mel_frames :
126+ start = 0
127+ else :
128+ start = random .randint (0 , record ['spectrogram' ].shape [0 ] - 1 - crop_mel_frames )
117129 end = start + crop_mel_frames
118130 if self .infer :
119131 record ['spectrogram' ] = record ['spectrogram' ].T
0 commit comments