Skip to content

Commit 221f1cd

Browse files
committed
🌱 Add duration to mask exporter, modify Tacotron2 and dataloader to accept
1 parent 5b15bb9 commit 221f1cd

File tree

3 files changed

+204
-4
lines changed

3 files changed

+204
-4
lines changed
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import os
2+
import shutil
3+
from tqdm import tqdm
4+
import argparse
5+
6+
from scipy.ndimage import zoom
7+
from skimage.data import camera
8+
import numpy as np
9+
from scipy.spatial.distance import cdist
10+
11+
def safemkdir(dirn):
    """Create the directory ``dirn`` if it does not already exist.

    Missing parent directories are created as well, and an existing
    directory is left untouched.

    Args:
        dirn: Path of the directory to create.

    Raises:
        FileExistsError: If ``dirn`` exists but is not a directory.
    """
    # exist_ok=True removes the check-then-create race of isdir() + mkdir().
    os.makedirs(dirn, exist_ok=True)
14+
15+
from pathlib import Path
16+
17+
def duration_to_alignment(in_duration):
    """Expand a sequence of durations into a hard (0/1) alignment matrix.

    Row ``i`` of the result holds ones on ``in_duration[i]`` consecutive
    columns, starting right after the columns used by row ``i - 1``; every
    other entry is zero. Each column therefore belongs to exactly one row.

    Args:
        in_duration: 1-D sequence of integer durations, one per output row.
            Values are assumed to be non-negative.

    Returns:
        ``float32`` array of shape ``(len(in_duration), sum(in_duration))``.
        An empty input yields an array of shape ``(0, 0)``.
    """
    # Force an integer dtype so that an empty input does not turn the total
    # into a float, which np.zeros would reject as a shape.
    durations = np.asarray(in_duration, dtype=np.int64)
    num_chars = len(durations)
    total_len = int(durations.sum())

    attention = np.zeros(shape=(num_chars, total_len), dtype=np.float32)

    # rows[j] is the index of the row that owns column j.
    rows = np.repeat(np.arange(num_chars), durations)
    attention[rows, np.arange(total_len)] = 1.0

    return attention
31+
32+
33+
def rescale_alignment(in_alignment, in_targcharlen):
    """Resample an alignment to ``in_targcharlen`` rows and binarize it.

    The matrix is zoomed along axis 0 only (axis 1 keeps its length). Every
    resampled value below 0.5 becomes 0.0 and every other value becomes 1.0.

    Args:
        in_alignment: 2-D array to rescale.
        in_targcharlen: Number of rows the result must have.

    Returns:
        Tuple ``(zoomed, pivot_points)``. ``zoomed`` is the binarized matrix
        with ``in_targcharlen`` rows and the dtype produced by ``zoom``.
        ``pivot_points`` is the list of ``(row, col)`` tuples, in row-major
        order, of the entries of ``zoomed`` that equal 1.
    """
    current_x = in_alignment.shape[0]
    x_ratio = in_targcharlen / current_x

    zoomed = zoom(in_alignment, (x_ratio, 1.0), mode="nearest")

    # Interpolation yields fractional values; snap them back to hard 0/1.
    zoomed = np.where(zoomed < 0.5, 0.0, 1.0).astype(zoomed.dtype)

    if zoomed.shape[0] != in_targcharlen:
        print("Zooming didn't rshape well, explicitly reshaping")
        # Crop surplus rows or append zero rows. An explicit copy avoids the
        # reference checks of in-place ndarray.resize.
        fixed = np.zeros((in_targcharlen, zoomed.shape[1]), dtype=zoomed.dtype)
        keep = min(in_targcharlen, zoomed.shape[0])
        fixed[:keep] = zoomed[:keep]
        zoomed = fixed

    # Collect pivots from the final matrix so that none of them refers to a
    # row that was cropped away above.
    pivot_points = [tuple(p) for p in np.argwhere(zoomed == 1.0).tolist()]

    return zoomed, pivot_points
57+
58+
59+
def gather_dist(in_mtr, in_points):
    """Compute Euclidean distances from every cell of ``in_mtr`` to each point.

    Only the first two dimensions of ``in_mtr`` are used, to enumerate cell
    coordinates; its values are never read.

    Args:
        in_mtr: Array whose ``shape[0]`` and ``shape[1]`` define the grid.
        in_points: Sequence of ``(x, y)`` coordinates.

    Returns:
        Array of shape ``(shape[0] * shape[1], len(in_points))``. Row ``k``
        belongs to cell ``(k // shape[1], k % shape[1])`` (row-major order)
        and column ``j`` to ``in_points[j]``.
    """
    rows, cols = in_mtr.shape[0], in_mtr.shape[1]

    # Row-major list of (x, y) cell coordinates, built without a Python loop.
    full_coords = np.indices((rows, cols)).reshape(2, -1).T

    if len(in_points) == 0:
        # cdist rejects an empty point list; return the empty distance table.
        return np.empty((rows * cols, 0), dtype=np.float64)

    return cdist(full_coords, in_points, "euclidean")
69+
70+
71+
72+
73+
def create_guided(in_align, in_pvt, looseness):
    """Build a guided-attention map from a set of pivot points.

    Each cell receives ``(d / real_loose) ** 2`` clipped to ``[0, 1]``, where
    ``d`` is the Euclidean distance from the cell to its closest pivot and
    ``real_loose = looseness / 100 * (rows + cols)``.

    Args:
        in_align: 2-D array; only its shape is used.
        in_pvt: Sequence of ``(x, y)`` pivot coordinates.
        looseness: Looseness setting; larger values give smaller map values
            at the same distance.

    Returns:
        ``float32`` array with the shape of ``in_align``. When ``in_pvt`` is
        empty there is no distance to measure, so the all-ones map is
        returned.
    """
    new_att = np.ones(in_align.shape, dtype=np.float32)

    if len(in_pvt) == 0:
        # A minimum over zero pivots is undefined; keep the initial map.
        return new_att

    # It is dramatically faster to gather all the distances in one call than
    # to compute them point by point.
    dist_arr = gather_dist(in_align, in_pvt)

    # Scale looseness based on attention size (addition works better than
    # mul). Also divide by 100 because having the user input 3.35 is nicer.
    real_loose = (looseness / 100) * (new_att.shape[0] + new_att.shape[1])

    # Rows of dist_arr follow row-major cell order, so the per-cell minimum
    # reshapes straight back onto the map.
    nearest = dist_arr.min(axis=1)
    new_att = np.power(nearest / real_loose, 2).reshape(new_att.shape).astype(np.float32)

    return np.clip(new_att, 0.0, 1.0)
95+
96+
def get_pivot_points(in_att):
    """List the coordinates of all entries of ``in_att`` greater than 0.8.

    Args:
        in_att: 2-D array of values.

    Returns:
        List of ``(row, col)`` tuples in row-major order; empty when no
        entry exceeds the threshold.
    """
    # argwhere walks the array in row-major order, the same order as a
    # nested row/column loop.
    return [tuple(p) for p in np.argwhere(np.asarray(in_att) > 0.8).tolist()]
103+
104+
def main():
    """Turn dumped duration files into guided-attention alignment files.

    For the ``train`` and ``valid`` subfolders of ``--dump-dir``, every
    ``*.npy`` file in ``durations/`` is loaded together with the file of the
    matching name in ``ids/`` (``-durations`` replaced by ``-ids``). The
    durations are expanded into an alignment, rescaled when its row count
    differs from the id count, converted into a guided map and saved under
    ``alignments/`` with ``-ids`` replaced by ``-alignment``.
    """
    parser = argparse.ArgumentParser(description="Postprocess durations to become alignments")
    parser.add_argument(
        "--dump-dir",
        default="dump",
        type=str,
        help="Path of dump directory",
    )
    parser.add_argument(
        "--looseness",
        default=3.5,
        type=float,
        help="Looseness of the generated guided attention map. Lower values = tighter",
    )
    args = parser.parse_args()
    dump_dir = args.dump_dir
    dump_sets = ["train", "valid"]

    for d_set in dump_sets:
        full_fol = os.path.join(dump_dir, d_set)
        align_path = os.path.join(full_fol, "alignments")

        ids_path = os.path.join(full_fol, "ids")
        durations_path = os.path.join(full_fol, "durations")

        safemkdir(align_path)

        for duration_fn in tqdm(os.listdir(durations_path)):
            # Match on the extension only; a substring test would also pick
            # up names such as "x.npy.bak".
            if not duration_fn.endswith(".npy"):
                continue

            id_fn = duration_fn.replace("-durations", "-ids")

            id_path = os.path.join(ids_path, id_fn)
            duration_path = os.path.join(durations_path, duration_fn)

            duration_arr = np.load(duration_path)
            id_arr = np.load(id_path)

            id_true_size = len(id_arr)

            align = duration_to_alignment(duration_arr)

            # The alignment must have one row per id; rescale when the
            # duration count disagrees with the id count.
            if align.shape[0] != id_true_size:
                align, points = rescale_alignment(align, id_true_size)
            else:
                points = get_pivot_points(align)

            if len(points) == 0:
                print("WARNING points are empty for", id_fn)

            align = create_guided(align, points, args.looseness)

            align_fn = id_fn.replace("-ids", "-alignment")
            align_full_fn = os.path.join(align_path, align_fn)

            np.save(align_full_fn, align.astype("float32"))
162+
163+
164+
165+
166+
167+
# Run the conversion only when the file is executed as a script.
if __name__ == "__main__":
    main()
169+

‎examples/tacotron2/tacotron_dataset.py‎

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
root_dir,
3636
charactor_query="*-ids.npy",
3737
mel_query="*-norm-feats.npy",
38+
align_query="",
3839
charactor_load_fn=np.load,
3940
mel_load_fn=np.load,
4041
mel_length_threshold=0,
@@ -52,6 +53,7 @@ def __init__(
5253
charactor_query (str): Query to find charactor files in root_dir.
5354
mel_query (str): Query to find feature files in root_dir.
5455
charactor_load_fn (func): Function to load charactor file.
56+
align_query (str): Query to find FAL files in root_dir. If empty, we use stock guided attention loss
5557
mel_load_fn (func): Function to load feature file.
5658
mel_length_threshold (int): Threshold to remove short feature files.
5759
reduction_factor (int): Reduction factor on Tacotron-2 paper.
@@ -67,6 +69,8 @@ def __init__(
6769
# find all of charactor and mel files.
6870
charactor_files = sorted(find_files(root_dir, charactor_query))
6971
mel_files = sorted(find_files(root_dir, mel_query))
72+
73+
7074
mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
7175
char_lengths = [charactor_load_fn(f).shape[0] for f in charactor_files]
7276

@@ -76,6 +80,16 @@ def __init__(
7680
len(mel_files) == len(charactor_files) == len(mel_lengths)
7781
), f"Number of charactor, mel and duration files are different \
7882
({len(mel_files)} vs {len(charactor_files)} vs {len(mel_lengths)})."
83+
84+
self.align_files = []
85+
86+
if len(align_query) > 1:
87+
align_files = sorted(find_files(root_dir, align_query))
88+
assert len(align_files) == len(mel_files),f"Number of align files ({len(align_files)}) and mel files ({len(mel_files)}) are different"
89+
logging.info("Using FAL loss")
90+
self.align_files = align_files
91+
else:
92+
logging.info("Using guided attention loss")
7993

8094
if ".npy" in charactor_query:
8195
suffix = charactor_query[1:]
@@ -114,11 +128,13 @@ def generator(self, utt_ids):
114128
for i, utt_id in enumerate(utt_ids):
115129
mel_file = self.mel_files[i]
116130
charactor_file = self.charactor_files[i]
131+
align_file = self.align_files[i] if len(self.align_files) > 1 else ""
117132

118133
items = {
119134
"utt_ids": utt_id,
120135
"mel_files": mel_file,
121136
"charactor_files": charactor_file,
137+
"align_files": align_file,
122138
}
123139

124140
yield items
@@ -127,6 +143,8 @@ def generator(self, utt_ids):
127143
def _load_data(self, items):
128144
mel = tf.numpy_function(np.load, [items["mel_files"]], tf.float32)
129145
charactor = tf.numpy_function(np.load, [items["charactor_files"]], tf.int32)
146+
g_att = tf.numpy_function(np.load, [items["align_files"]], tf.float32) if len(self.align_files) > 1 else None
147+
130148
mel_length = len(mel)
131149
char_length = len(charactor)
132150
# padding mel to make its length is multiple of reduction factor.
@@ -149,6 +167,7 @@ def _load_data(self, items):
149167
"mel_gts": mel,
150168
"mel_lengths": mel_length,
151169
"real_mel_lengths": real_mel_length,
170+
"g_attentions": g_att,
152171
}
153172

154173
return items
@@ -187,10 +206,14 @@ def create(
187206
)
188207

189208
# calculate guided attention
190-
datasets = datasets.map(
191-
lambda items: self._guided_attention(items),
192-
tf.data.experimental.AUTOTUNE
193-
)
209+
if len(self.align_files) < 1:
210+
datasets = datasets.map(
211+
lambda items: self._guided_attention(items),
212+
tf.data.experimental.AUTOTUNE
213+
)
214+
215+
216+
194217

195218
datasets = datasets.filter(
196219
lambda x: x["mel_lengths"] > self.mel_length_threshold
@@ -249,6 +272,7 @@ def get_output_dtypes(self):
249272
"utt_ids": tf.string,
250273
"mel_files": tf.string,
251274
"charactor_files": tf.string,
275+
"align_files": tf.string,
252276
}
253277
return output_types
254278

‎examples/tacotron2/train_tacotron2.py‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,9 @@ def main():
336336
nargs="?",
337337
help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
338338
)
339+
parser.add_argument(
340+
"--use-fal", default=0, type=int, help="Use forced alignment guided attention loss or regular"
341+
)
339342
args = parser.parse_args()
340343

341344
# return strategy
@@ -347,6 +350,7 @@ def main():
347350

348351
args.mixed_precision = bool(args.mixed_precision)
349352
args.use_norm = bool(args.use_norm)
353+
args.use_fal = bool(args.use_fal)
350354

351355
# set logger
352356
if args.verbose > 1:
@@ -394,6 +398,7 @@ def main():
394398
if config["format"] == "npy":
395399
charactor_query = "*-ids.npy"
396400
mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy"
401+
align_query = "*-alignment.npy" if args.use_fal is True else ""
397402
charactor_load_fn = np.load
398403
mel_load_fn = np.load
399404
else:
@@ -409,6 +414,7 @@ def main():
409414
mel_length_threshold=mel_length_threshold,
410415
reduction_factor=config["tacotron2_params"]["reduction_factor"],
411416
use_fixed_shapes=config["use_fixed_shapes"],
417+
align_query=align_query,
412418
)
413419

414420
# update max_mel_length and max_char_length to config
@@ -438,6 +444,7 @@ def main():
438444
mel_length_threshold=mel_length_threshold,
439445
reduction_factor=config["tacotron2_params"]["reduction_factor"],
440446
use_fixed_shapes=False, # don't need apply fixed shape for evaluation.
447+
align_query=align_query,
441448
).create(
442449
is_shuffle=config["is_shuffle"],
443450
allow_cache=config["allow_cache"],

0 commit comments

Comments
 (0)