Skip to content

Commit c2b90bd

Browse files
Merge pull request #49 from choiHkk/hotfix/datacollator_sampling_rate
[fix] Add fixed sampling rate to feature extractor
2 parents bdb0363 + 1d0cc01 commit c2b90bd

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

training/data.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
3131
audios = [feature[self.audio_column_name]["array"] for feature in features]
3232
len_audio = [len(audio) for audio in audios]
3333

34-
batch = self.feature_extractor(audios, return_tensors="pt", padding=self.padding, max_length=self.max_length)
34+
# Since resampling has already been performed in the 'load_multiple_datasets' function,
35+
# a fixed sampling rate (the feature extractor's own sampling_rate, 44100 Hz) is passed to the feature_extractor.
36+
sampling_rate = self.feature_extractor.sampling_rate
37+
batch = self.feature_extractor(
38+
audios, sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
39+
)
3540
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
3641
return batch
3742

0 commit comments

Comments
 (0)