Skip to content

Commit 1d0cc01

Browse files
committed
[fix] Add fixed sampling rate to feature extractor
1 parent bdb0363 commit 1d0cc01

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

training/data.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
3131
audios = [feature[self.audio_column_name]["array"] for feature in features]
3232
len_audio = [len(audio) for audio in audios]
3333

34-
batch = self.feature_extractor(audios, return_tensors="pt", padding=self.padding, max_length=self.max_length)
34+
# since resampling has already been performed in the 'load_multiple_datasets' function,
35+
# a fixed sampling_rate(44100hz) is passed to the feature_extractor.
36+
sampling_rate = self.feature_extractor.sampling_rate
37+
batch = self.feature_extractor(
38+
audios, sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
39+
)
3540
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
3641
return batch
3742

0 commit comments

Comments
 (0)