-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathdata.py
More file actions
131 lines (112 loc) · 4.32 KB
/
data.py
File metadata and controls
131 lines (112 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# AudioDataLoader-style dataset loading for MindSpore.
# Adapted from https://github.com/kaituoxu/TasNet/blob/master/src/train.py
""" data """
import json
import os
import numpy as np
import mindaudio.data.io as io
class DatasetGenerator:
    """In-memory dataset of framed mixture / source pairs.

    Every utterance listed in the JSON manifests is read and framed eagerly in
    ``__init__``. Indexing returns ``(mixture, length, sources)`` for a single
    utterance.
    """

    def __init__(self, json_dir, batch_size, sample_rate=8000, L=int(8000 * 0.005)):
        """
        Args:
            json_dir: directory containing ``mix_clean.json``, ``s1.json`` and
                ``s2.json``. Each file holds a list whose items are
                ``(wav_file, num_samples)`` pairs.
            batch_size: number of utterances loaded and padded together; must
                be a positive integer.
            sample_rate: forwarded unchanged to ``load_mixtures_and_sources``.
            L: frame length in samples, forwarded to
                ``load_mixtures_and_sources``.

        Raises:
            ValueError: if ``batch_size`` is not positive, or if the three
                manifests list different numbers of utterances.
        """
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        def _load_manifest(name):
            # JSON text is UTF-8; do not depend on the platform locale.
            with open(os.path.join(json_dir, name), "r", encoding="utf-8") as f:
                return json.load(f)

        mix_infos = _load_manifest("mix_clean.json")
        s1_infos = _load_manifest("s1.json")
        s2_infos = _load_manifest("s2.json")
        if not len(mix_infos) == len(s1_infos) == len(s2_infos):
            # Unequal manifests would be silently truncated / misaligned below.
            raise ValueError(
                "manifests list different numbers of utterances: "
                f"mix={len(mix_infos)}, s1={len(s1_infos)}, s2={len(s2_infos)}"
            )

        # Sort by (#samples, path), longest first ("impl bucket"). Alignment of
        # the three independently sorted lists presumably relies on the paths
        # ordering the same way in every manifest -- verify against the data.
        def _sort(infos):
            return sorted(infos, key=lambda info: (int(info[1]), info[0]), reverse=True)

        sorted_mix_infos = _sort(mix_infos)
        sorted_s1_infos = _sort(s1_infos)
        sorted_s2_infos = _sort(s2_infos)

        self.mixture = []
        self.len = []
        self.sources = []
        # An empty manifest yields an empty dataset instead of crashing.
        for start in range(0, len(sorted_mix_infos), batch_size):
            end = start + batch_size
            meta = [
                sorted_mix_infos[start:end],
                sorted_s1_infos[start:end],
                sorted_s2_infos[start:end],
                sample_rate,
                L,
            ]
            mixtures_pad, ilens, sources_pad = self.sort_and_pad(meta)
            # Flatten the batch back into per-utterance items.
            self.mixture.extend(mixtures_pad)
            self.len.extend(ilens)
            self.sources.extend(sources_pad)

    def __getitem__(self, index):
        """Return ``(mixture, length, sources)`` for utterance ``index``."""
        return self.mixture[index], self.len[index], self.sources[index]

    def __len__(self):
        """Return the number of utterances held."""
        return len(self.mixture)

    def sort_and_pad(self, batch):
        """Load one batch of utterances and zero-pad it into stacked arrays.

        Despite the name, no sorting happens here; ``__init__`` sorts the
        manifests before slicing them into batches.

        Args:
            batch: ``[mix_infos, s1_infos, s2_infos, sample_rate, L]`` as
                accepted by ``load_mixtures_and_sources``.

        Returns:
            mixtures_pad: padded mixtures, N x K x L.
            ilens: array of each mixture's first-dimension length (K).
            sources_pad: padded sources, N x C x K x L.
        """
        mixtures, sources = load_mixtures_and_sources(batch)
        ilens = np.array([mix.shape[0] for mix in mixtures])
        mixtures_pad = pad_list(mixtures)
        # N x K x L x C -> N x C x K x L
        sources_pad = pad_list(sources).transpose((0, 3, 1, 2))
        return mixtures_pad, ilens, sources_pad
def load_mixtures_and_sources(batch, pad_len=132800):
    """Read, zero-pad and frame one batch of mixture / source wav files.

    Args:
        batch: ``[mix_infos, s1_infos, s2_infos, sample_rate, L]``. Each
            ``*_infos`` is a list of ``(wav_path, num_samples)`` items aligned
            by index. ``L`` is the frame length in samples. ``sample_rate`` is
            carried in the batch but not used here.
        pad_len: minimum length, in samples, that every signal is zero-padded
            to before framing. It is rounded up to a whole number of frames.
            The default is the value previously hard-coded here.

    Returns:
        mixtures: a list containing B items, each a K x L np.ndarray.
        sources: a list containing B items, each a K x L x C np.ndarray with
            C = 2 (s1 and s2 stacked on the last axis).
        K = ceil(pad_len / L) is the same for every item (3320 for the
        defaults with L = 40).

    Raises:
        ValueError: if ``L`` is not positive, the info lists differ in length,
            the recorded sample counts of an utterance disagree, or a signal is
            longer than the padded length.
    """
    mix_infos, s1_infos, s2_infos, _sample_rate, L = batch
    if L < 1:
        raise ValueError(f"frame length L must be >= 1, got {L}")
    if not len(mix_infos) == len(s1_infos) == len(s2_infos):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            "mixture/source info lists differ in length: "
            f"{len(mix_infos)}, {len(s1_infos)}, {len(s2_infos)}"
        )

    # Derive the frame count from L (it used to be a hard-coded 3320, which
    # only worked for L == 40), rounding up so the reshape always succeeds.
    num_frames = -(-pad_len // L)
    target_len = num_frames * L

    def _pad_and_frame(signal, path):
        """Zero-pad ``signal`` to ``target_len`` samples and reshape to K x L."""
        shortfall = target_len - len(signal)
        if shortfall < 0:
            raise ValueError(
                f"{path} has {len(signal)} samples, more than the padded "
                f"length {target_len}"
            )
        padded = np.concatenate([signal, np.zeros([shortfall], np.float32)])
        return np.reshape(padded, [num_frames, L])

    mixtures, sources = [], []
    # for each utterance
    for mix_info, s1_info, s2_info in zip(mix_infos, s1_infos, s2_infos):
        # Explicit check instead of `assert`, which is stripped under -O.
        if not mix_info[1] == s1_info[1] == s2_info[1]:
            raise ValueError(
                f"sample counts disagree for {mix_info[0]}: "
                f"{mix_info[1]}, {s1_info[1]}, {s2_info[1]}"
            )
        # read wav files
        mix, _ = io.read(mix_info[0])
        s1, _ = io.read(s1_info[0])
        s2, _ = io.read(s2_info[0])
        mixtures.append(_pad_and_frame(mix, mix_info[0]))
        # merge s1 and s2: K x L x C, C = 2
        sources.append(
            np.dstack((_pad_and_frame(s1, s1_info[0]), _pad_and_frame(s2, s2_info[0])))
        )
    return mixtures, sources
def pad_list(xs):
    """Zero-pad a list of same-rank arrays into one float32 batch array.

    Each output axis is sized to the largest extent of that axis across the
    inputs, so arrays may differ along any axis, not only the first.

    Args:
        xs: non-empty sequence of arrays (or array-likes) of equal rank.

    Returns:
        float32 array of shape ``(len(xs), *max_shape)``. ``out[i]`` holds
        ``xs[i]`` in its leading corner and zeros elsewhere.

    Raises:
        ValueError: if ``xs`` is empty or its arrays differ in rank.
    """
    if len(xs) == 0:
        raise ValueError("pad_list() requires at least one array")
    shapes = [np.shape(x) for x in xs]
    if len({len(shape) for shape in shapes}) != 1:
        raise ValueError(f"pad_list() requires arrays of equal rank, got shapes {shapes}")
    # Take the maximum per axis. A plain max() over shape tuples compares them
    # lexicographically and can pick a shape that is too small on later axes.
    max_shape = tuple(max(dims) for dims in zip(*shapes))
    pad = np.zeros((len(xs),) + max_shape, np.float32)
    for i, (x, shape) in enumerate(zip(xs, shapes)):
        # Fill the leading corner on every axis; the rest stays zero.
        pad[(i,) + tuple(slice(0, d) for d in shape)] = x
    return pad