Commit 59e27bd

📈Reformat with black

1 parent 2883b6e commit 59e27bd

File tree

7 files changed: +126 -119 lines changed

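Every hunk in this commit is a mechanical restyle by the black formatter: spaces after commas, over-long calls wrapped, duplicate blank lines collapsed, trailing whitespace stripped. As a minimal sketch of the rule set at work (assuming black is installed; the exact command used for this commit is not shown on this page), the same rewrites can be reproduced through black's Python API:

import black

# One of the lines touched below, as it looked before formatting.
src = 'outdpost = os.path.join(args.outdir,"postnets")\n'

# format_str applies the same rules as running `black` on a file.
print(black.format_str(src, mode=black.Mode()))
# -> outdpost = os.path.join(args.outdir, "postnets")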

examples/fastspeech2/extractfs_postnets.py

Lines changed: 10 additions & 7 deletions
@@ -96,12 +96,12 @@ def main():
         os.makedirs(args.outdir)
 
     # load config
-
-    outdpost = os.path.join(args.outdir,"postnets")
-
+
+    outdpost = os.path.join(args.outdir, "postnets")
+
     if not os.path.exists(outdpost):
         os.makedirs(outdpost)
-
+
     with open(args.config) as f:
         config = yaml.load(f, Loader=yaml.Loader)
     config.update(vars(args))
@@ -118,7 +118,9 @@ def main():
         charactor_query=char_query,
         charactor_load_fn=char_load_fn,
     )
-    dataset = dataset.create(batch_size=1) # force batch size to 1 otherwise it may miss certain files
+    dataset = dataset.create(
+        batch_size=1
+    )  # force batch size to 1 otherwise it may miss certain files
 
     # define model and load checkpoint
     fastspeech2 = TFFastSpeech2(
@@ -134,8 +136,9 @@ def main():
         mel_lens = data["mel_lengths"]
 
         # fastspeech inference.
-        masked_mel_before, masked_mel_after , duration_outputs, _, _ = fastspeech2(**data,training=True)
-
+        masked_mel_before, masked_mel_after, duration_outputs, _, _ = fastspeech2(
+            **data, training=True
+        )
 
         # convert to numpy
         masked_mel_befores = masked_mel_before.numpy()

examples/multiband_melgan_hf/train_multiband_melgan_hf.py

Lines changed: 7 additions & 7 deletions
@@ -115,13 +115,13 @@ def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer, pqmf):
 
     def compute_per_example_generator_losses(self, batch, outputs):
         """Compute per example generator losses and return dict_metrics_losses
-        Note that all element of the loss MUST has a shape [batch_size] and 
+        Note that all element of the loss MUST has a shape [batch_size] and
         the keys of dict_metrics_losses MUST be in self.list_metrics_name.
 
         Args:
             batch: dictionary batch input return from dataloader
             outputs: outputs of the model
- 
+
         Returns:
             per_example_losses: per example losses for each GPU, shape [B]
             dict_metrics_losses: dictionary loss.
@@ -172,7 +172,9 @@ def compute_per_example_generator_losses(self, batch, outputs):
         adv_loss /= i + 1
         gen_loss += self.config["lambda_adv"] * adv_loss
 
-        dict_metrics_losses.update({"adversarial_loss": adv_loss},)
+        dict_metrics_losses.update(
+            {"adversarial_loss": adv_loss},
+        )
 
         dict_metrics_losses.update({"gen_loss": gen_loss})
         dict_metrics_losses.update({"subband_spectral_convergence_loss": sub_sc_loss})
@@ -185,13 +187,13 @@ def compute_per_example_generator_losses(self, batch, outputs):
 
     def compute_per_example_discriminator_losses(self, batch, gen_outputs):
         """Compute per example discriminator losses and return dict_metrics_losses
-        Note that all element of the loss MUST has a shape [batch_size] and 
+        Note that all element of the loss MUST has a shape [batch_size] and
         the keys of dict_metrics_losses MUST be in self.list_metrics_name.
 
         Args:
             batch: dictionary batch input return from dataloader
             outputs: outputs of the model
- 
+
         Returns:
             per_example_losses: per example losses for each GPU, shape [B]
             dict_metrics_losses: dictionary loss.
@@ -400,7 +402,6 @@ def main():
     else:
         raise ValueError("Only npy are supported.")
 
-
     if args.postnets is True:
         mel_query = "*-postnet.npy"
         logging.info("Using postnets")
@@ -553,4 +554,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
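The exploded dict_metrics_losses.update(...) hunk above is black's magic trailing comma at work: a comma left before the closing parenthesis makes black keep the call split across lines instead of collapsing it (this is an inference from the diff and depends on the black version; the commit itself does not say so). A self-contained illustration with made-up names:

metrics = {}
adv_loss = 0.25

# The trailing comma after the dict argument keeps this call exploded:
metrics.update(
    {"adversarial_loss": adv_loss},
)

# Without the trailing comma, black collapses the same call onto one line:
metrics.update({"adversarial_loss": adv_loss})
print(metrics)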

examples/tacotron2/export_align.py

Lines changed: 61 additions & 62 deletions
@@ -8,73 +8,73 @@
 import numpy as np
 from scipy.spatial.distance import cdist
 
+
 def safemkdir(dirn):
     if not os.path.isdir(dirn):
         os.mkdir(dirn)
-
+
+
 from pathlib import Path
 
+
 def duration_to_alignment(in_duration):
     total_len = np.sum(in_duration)
     num_chars = len(in_duration)
 
-    attention = np.zeros(shape=(num_chars,total_len),dtype=np.float32)
+    attention = np.zeros(shape=(num_chars, total_len), dtype=np.float32)
     y_offset = 0
 
     for duration_idx, duration_val in enumerate(in_duration):
-        for y_val in range(0,duration_val):
+        for y_val in range(0, duration_val):
             attention[duration_idx][y_offset + y_val] = 1.0
-
+
         y_offset += duration_val
-
+
     return attention
 
 
-def rescale_alignment(in_alignment,in_targcharlen):
+def rescale_alignment(in_alignment, in_targcharlen):
     current_x = in_alignment.shape[0]
     x_ratio = in_targcharlen / current_x
     pivot_points = []
-
-    zoomed = zoom(in_alignment,(x_ratio,1.0),mode="nearest")
 
-    for x_v in range(0,zoomed.shape[0]):
-        for y_v in range(0,zoomed.shape[1]):
+    zoomed = zoom(in_alignment, (x_ratio, 1.0), mode="nearest")
+
+    for x_v in range(0, zoomed.shape[0]):
+        for y_v in range(0, zoomed.shape[1]):
             val = zoomed[x_v][y_v]
             if val < 0.5:
                 val = 0.0
             else:
                 val = 1.0
-                pivot_points.append( (x_v,y_v) )
+                pivot_points.append((x_v, y_v))
 
             zoomed[x_v][y_v] = val
-
-
+
     if zoomed.shape[0] != in_targcharlen:
         print("Zooming didn't rshape well, explicitly reshaping")
-        zoomed.resize((in_targcharlen,in_alignment.shape[1]))
+        zoomed.resize((in_targcharlen, in_alignment.shape[1]))
 
     return zoomed, pivot_points
 
 
-def gather_dist(in_mtr,in_points):
-    #initialize with known size for fast
-    full_coords = [(0,0) for x in range(in_mtr.shape[0] * in_mtr.shape[1])]
+def gather_dist(in_mtr, in_points):
+    # initialize with known size for fast
+    full_coords = [(0, 0) for x in range(in_mtr.shape[0] * in_mtr.shape[1])]
     i = 0
     for x in range(0, in_mtr.shape[0]):
         for y in range(0, in_mtr.shape[1]):
-            full_coords[i] = (x,y)
+            full_coords[i] = (x, y)
             i += 1
-
-    return cdist(full_coords, in_points,"euclidean")
-
-
+
+    return cdist(full_coords, in_points, "euclidean")
 
 
-def create_guided(in_align,in_pvt,looseness):
-    new_att = np.ones(in_align.shape,dtype=np.float32)
+def create_guided(in_align, in_pvt, looseness):
+    new_att = np.ones(in_align.shape, dtype=np.float32)
     # It is dramatically faster that we first gather all the points and calculate than do it manually
     # for each point in for loop
-    dist_arr = gather_dist(in_align,in_pvt)
+    dist_arr = gather_dist(in_align, in_pvt)
     # Scale looseness based on attention size. (addition works better than mul). Also divide by 100
     # because having user input 3.35 is nicer
     real_loose = (looseness / 100) * (new_att.shape[0] + new_att.shape[1])
@@ -85,57 +85,61 @@ def create_guided(in_align,in_pvt,looseness):
 
         closest_pvt = in_pvt[min_point_idx]
         distance = dist_arr[g_idx][min_point_idx] / real_loose
-        distance = np.power(distance,2)
+        distance = np.power(distance, 2)
 
         g_idx += 1
-
-        new_att[x,y] = distance
 
-    return np.clip(new_att,0.0,1.0)
+        new_att[x, y] = distance
+
+    return np.clip(new_att, 0.0, 1.0)
+
 
 def get_pivot_points(in_att):
     ret_points = []
     for x in range(0, in_att.shape[0]):
         for y in range(0, in_att.shape[1]):
-            if in_att[x,y] > 0.8:
-                ret_points.append((x,y))
+            if in_att[x, y] > 0.8:
+                ret_points.append((x, y))
     return ret_points
 
+
 def main():
-    parser = argparse.ArgumentParser(description="Postprocess durations to become alignments")
+    parser = argparse.ArgumentParser(
+        description="Postprocess durations to become alignments"
+    )
     parser.add_argument(
-            "--dump-dir",
-            default="dump",
-            type=str,
-            help="Path of dump directory",
+        "--dump-dir",
+        default="dump",
+        type=str,
+        help="Path of dump directory",
     )
     parser.add_argument(
-            "--looseness",
-            default=3.5,
-            type=float,
-            help="Looseness of the generated guided attention map. Lower values = tighter",
+        "--looseness",
+        default=3.5,
+        type=float,
+        help="Looseness of the generated guided attention map. Lower values = tighter",
     )
     args = parser.parse_args()
     dump_dir = args.dump_dir
-    dump_sets = ["train","valid"]
+    dump_sets = ["train", "valid"]
 
     for d_set in dump_sets:
-        full_fol = os.path.join(dump_dir,d_set)
-        align_path = os.path.join(full_fol,"alignments")
+        full_fol = os.path.join(dump_dir, d_set)
+        align_path = os.path.join(full_fol, "alignments")
 
-        ids_path = os.path.join(full_fol,"ids")
-        durations_path = os.path.join(full_fol,"durations")
+        ids_path = os.path.join(full_fol, "ids")
+        durations_path = os.path.join(full_fol, "durations")
 
         safemkdir(align_path)
 
         for duration_fn in tqdm(os.listdir(durations_path)):
             if not ".npy" in duration_fn:
-                    continue
-
-            id_fn = duration_fn.replace("-durations","-ids")
+                continue
 
-            id_path = os.path.join(ids_path,id_fn)
-            duration_path = os.path.join(durations_path,duration_fn)
+            id_fn = duration_fn.replace("-durations", "-ids")
+
+            id_path = os.path.join(ids_path, id_fn)
+            duration_path = os.path.join(durations_path, duration_fn)
 
             duration_arr = np.load(duration_path)
             id_arr = np.load(id_path)
@@ -145,25 +149,20 @@ def main():
             align = duration_to_alignment(duration_arr)
 
             if align.shape[0] != id_true_size:
-                align, points = rescale_alignment(align,id_true_size)
+                align, points = rescale_alignment(align, id_true_size)
             else:
                 points = get_pivot_points(align)
-
-            if len(points) == 0:
-                print("WARNING points are empty for",id_fn)
 
-            align = create_guided(align,points,args.looseness)
+            if len(points) == 0:
+                print("WARNING points are empty for", id_fn)
 
-
-            align_fn = id_fn.replace("-ids","-alignment")
-            align_full_fn = os.path.join(align_path,align_fn)
-
-            np.save(align_full_fn,align.astype("float32"))
-
+            align = create_guided(align, points, args.looseness)
 
+            align_fn = id_fn.replace("-ids", "-alignment")
+            align_full_fn = os.path.join(align_path, align_fn)
 
+            np.save(align_full_fn, align.astype("float32"))
 
 
 if __name__ == "__main__":
     main()
-
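For orientation, here is a small usage sketch of duration_to_alignment as it reads after this reformat; the function body is copied from the hunk above, and the example durations are made-up values:

import numpy as np


def duration_to_alignment(in_duration):
    # Copied from the reformatted export_align.py: expand each character's
    # duration into a run of 1.0s along the time (frame) axis.
    total_len = np.sum(in_duration)
    num_chars = len(in_duration)

    attention = np.zeros(shape=(num_chars, total_len), dtype=np.float32)
    y_offset = 0

    for duration_idx, duration_val in enumerate(in_duration):
        for y_val in range(0, duration_val):
            attention[duration_idx][y_offset + y_val] = 1.0

        y_offset += duration_val

    return attention


# Three characters lasting 2, 3, and 1 frames: row i is 1.0 exactly over the
# frames attributed to character i, giving a (3, 6) hard alignment matrix.
print(duration_to_alignment(np.array([2, 3, 1])))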

examples/tacotron2/extract_postnets.py

Lines changed: 7 additions & 5 deletions
@@ -135,7 +135,9 @@ def main():
         reduction_factor=config["tacotron2_params"]["reduction_factor"],
         use_fixed_shapes=True,
     )
-    dataset = dataset.create(allow_cache=True, batch_size=args.batch_size, drop_remainder=False)
+    dataset = dataset.create(
+        allow_cache=True, batch_size=args.batch_size, drop_remainder=False
+    )
 
     # define model and load checkpoint
     tacotron2 = TFTacotron2(
@@ -170,11 +172,11 @@ def main():
         alignment_historys = alignment_historys.numpy()
         post_mel_outputs = post_mel_outputs.numpy()
         mel_gt = mel_gt.numpy()
-
-        outdpost = os.path.join(args.outdir,"postnets")
-
+
+        outdpost = os.path.join(args.outdir, "postnets")
+
         if not os.path.exists(outdpost):
-                os.makedirs(outdpost)
+            os.makedirs(outdpost)
 
         for i, alignment in enumerate(alignment_historys):
             real_char_length = input_lengths[i].numpy()
