88import numpy as np
99from scipy .spatial .distance import cdist
1010
11+
1112def safemkdir (dirn ):
1213 if not os .path .isdir (dirn ):
1314 os .mkdir (dirn )
14-
15+
16+
1517from pathlib import Path
1618
19+
1720def duration_to_alignment (in_duration ):
1821 total_len = np .sum (in_duration )
1922 num_chars = len (in_duration )
2023
21- attention = np .zeros (shape = (num_chars ,total_len ),dtype = np .float32 )
24+ attention = np .zeros (shape = (num_chars , total_len ), dtype = np .float32 )
2225 y_offset = 0
2326
2427 for duration_idx , duration_val in enumerate (in_duration ):
25- for y_val in range (0 ,duration_val ):
28+ for y_val in range (0 , duration_val ):
2629 attention [duration_idx ][y_offset + y_val ] = 1.0
27-
30+
2831 y_offset += duration_val
29-
32+
3033 return attention
3134
3235
33- def rescale_alignment (in_alignment ,in_targcharlen ):
36+ def rescale_alignment (in_alignment , in_targcharlen ):
3437 current_x = in_alignment .shape [0 ]
3538 x_ratio = in_targcharlen / current_x
3639 pivot_points = []
37-
38- zoomed = zoom (in_alignment ,(x_ratio ,1.0 ),mode = "nearest" )
3940
40- for x_v in range (0 ,zoomed .shape [0 ]):
41- for y_v in range (0 ,zoomed .shape [1 ]):
41+ zoomed = zoom (in_alignment , (x_ratio , 1.0 ), mode = "nearest" )
42+
43+ for x_v in range (0 , zoomed .shape [0 ]):
44+ for y_v in range (0 , zoomed .shape [1 ]):
4245 val = zoomed [x_v ][y_v ]
4346 if val < 0.5 :
4447 val = 0.0
4548 else :
4649 val = 1.0
47- pivot_points .append ( (x_v ,y_v ) )
50+ pivot_points .append ((x_v , y_v ))
4851
4952 zoomed [x_v ][y_v ] = val
50-
51-
53+
5254 if zoomed .shape [0 ] != in_targcharlen :
5355 print ("Zooming didn't rshape well, explicitly reshaping" )
54- zoomed .resize ((in_targcharlen ,in_alignment .shape [1 ]))
56+ zoomed .resize ((in_targcharlen , in_alignment .shape [1 ]))
5557
5658 return zoomed , pivot_points
5759
5860
59- def gather_dist (in_mtr ,in_points ):
60- #initialize with known size for fast
61- full_coords = [(0 ,0 ) for x in range (in_mtr .shape [0 ] * in_mtr .shape [1 ])]
61+ def gather_dist (in_mtr , in_points ):
62+ # initialize with known size for fast
63+ full_coords = [(0 , 0 ) for x in range (in_mtr .shape [0 ] * in_mtr .shape [1 ])]
6264 i = 0
6365 for x in range (0 , in_mtr .shape [0 ]):
6466 for y in range (0 , in_mtr .shape [1 ]):
65- full_coords [i ] = (x ,y )
67+ full_coords [i ] = (x , y )
6668 i += 1
67-
68- return cdist (full_coords , in_points ,"euclidean" )
69-
70-
69+
70+ return cdist (full_coords , in_points , "euclidean" )
7171
7272
73- def create_guided (in_align ,in_pvt ,looseness ):
74- new_att = np .ones (in_align .shape ,dtype = np .float32 )
73+ def create_guided (in_align , in_pvt , looseness ):
74+ new_att = np .ones (in_align .shape , dtype = np .float32 )
7575 # It is dramatically faster that we first gather all the points and calculate than do it manually
7676 # for each point in for loop
77- dist_arr = gather_dist (in_align ,in_pvt )
77+ dist_arr = gather_dist (in_align , in_pvt )
7878 # Scale looseness based on attention size. (addition works better than mul). Also divide by 100
7979 # because having user input 3.35 is nicer
8080 real_loose = (looseness / 100 ) * (new_att .shape [0 ] + new_att .shape [1 ])
@@ -85,57 +85,61 @@ def create_guided(in_align,in_pvt,looseness):
8585
8686 closest_pvt = in_pvt [min_point_idx ]
8787 distance = dist_arr [g_idx ][min_point_idx ] / real_loose
88- distance = np .power (distance ,2 )
88+ distance = np .power (distance , 2 )
8989
9090 g_idx += 1
91-
92- new_att [x ,y ] = distance
9391
94- return np .clip (new_att ,0.0 ,1.0 )
92+ new_att [x , y ] = distance
93+
94+ return np .clip (new_att , 0.0 , 1.0 )
95+
9596
9697def get_pivot_points (in_att ):
9798 ret_points = []
9899 for x in range (0 , in_att .shape [0 ]):
99100 for y in range (0 , in_att .shape [1 ]):
100- if in_att [x ,y ] > 0.8 :
101- ret_points .append ((x ,y ))
101+ if in_att [x , y ] > 0.8 :
102+ ret_points .append ((x , y ))
102103 return ret_points
103104
105+
104106def main ():
105- parser = argparse .ArgumentParser (description = "Postprocess durations to become alignments" )
107+ parser = argparse .ArgumentParser (
108+ description = "Postprocess durations to become alignments"
109+ )
106110 parser .add_argument (
107- "--dump-dir" ,
108- default = "dump" ,
109- type = str ,
110- help = "Path of dump directory" ,
111+ "--dump-dir" ,
112+ default = "dump" ,
113+ type = str ,
114+ help = "Path of dump directory" ,
111115 )
112116 parser .add_argument (
113- "--looseness" ,
114- default = 3.5 ,
115- type = float ,
116- help = "Looseness of the generated guided attention map. Lower values = tighter" ,
117+ "--looseness" ,
118+ default = 3.5 ,
119+ type = float ,
120+ help = "Looseness of the generated guided attention map. Lower values = tighter" ,
117121 )
118122 args = parser .parse_args ()
119123 dump_dir = args .dump_dir
120- dump_sets = ["train" ,"valid" ]
124+ dump_sets = ["train" , "valid" ]
121125
122126 for d_set in dump_sets :
123- full_fol = os .path .join (dump_dir ,d_set )
124- align_path = os .path .join (full_fol ,"alignments" )
127+ full_fol = os .path .join (dump_dir , d_set )
128+ align_path = os .path .join (full_fol , "alignments" )
125129
126- ids_path = os .path .join (full_fol ,"ids" )
127- durations_path = os .path .join (full_fol ,"durations" )
130+ ids_path = os .path .join (full_fol , "ids" )
131+ durations_path = os .path .join (full_fol , "durations" )
128132
129133 safemkdir (align_path )
130134
131135 for duration_fn in tqdm (os .listdir (durations_path )):
132136 if not ".npy" in duration_fn :
133- continue
134-
135- id_fn = duration_fn .replace ("-durations" ,"-ids" )
137+ continue
136138
137- id_path = os .path .join (ids_path ,id_fn )
138- duration_path = os .path .join (durations_path ,duration_fn )
139+ id_fn = duration_fn .replace ("-durations" , "-ids" )
140+
141+ id_path = os .path .join (ids_path , id_fn )
142+ duration_path = os .path .join (durations_path , duration_fn )
139143
140144 duration_arr = np .load (duration_path )
141145 id_arr = np .load (id_path )
@@ -145,25 +149,20 @@ def main():
145149 align = duration_to_alignment (duration_arr )
146150
147151 if align .shape [0 ] != id_true_size :
148- align , points = rescale_alignment (align ,id_true_size )
152+ align , points = rescale_alignment (align , id_true_size )
149153 else :
150154 points = get_pivot_points (align )
151-
152- if len (points ) == 0 :
153- print ("WARNING points are empty for" ,id_fn )
154155
155- align = create_guided (align ,points ,args .looseness )
156+ if len (points ) == 0 :
157+ print ("WARNING points are empty for" , id_fn )
156158
157-
158- align_fn = id_fn .replace ("-ids" ,"-alignment" )
159- align_full_fn = os .path .join (align_path ,align_fn )
160-
161- np .save (align_full_fn ,align .astype ("float32" ))
162-
159+ align = create_guided (align , points , args .looseness )
163160
161+ align_fn = id_fn .replace ("-ids" , "-alignment" )
162+ align_full_fn = os .path .join (align_path , align_fn )
164163
164+ np .save (align_full_fn , align .astype ("float32" ))
165165
166166
167167if __name__ == "__main__" :
168168 main ()
169-
0 commit comments