Skip to content

Commit 97965f4

Browse files
committed
fix mlm_prob, test=tts
1 parent c1395e3 commit 97965f4

File tree

7 files changed

+12
-21
lines changed

7 files changed

+12
-21
lines changed

examples/aishell3/ernie_sat/local/train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
88
--dev-metadata=dump/dev/norm/metadata.jsonl \
99
--config=${config_path} \
1010
--output-dir=${train_output_path} \
11-
--ngpu=1 \
11+
--ngpu=2 \
1212
--phones-dict=dump/phone_id_map.txt

examples/aishell3_vctk/ernie_sat/local/train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
88
--dev-metadata=dump/dev/norm/metadata.jsonl \
99
--config=${config_path} \
1010
--output-dir=${train_output_path} \
11-
--ngpu=1 \
11+
--ngpu=2 \
1212
--phones-dict=dump/phone_id_map.txt

examples/vctk/ernie_sat/conf/default.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ grad_clip: 1.0
7979
###########################################################
8080
# TRAINING SETTING #
8181
###########################################################
82-
max_epoch: 200
82+
max_epoch: 600
8383
num_snapshots: 5
8484

8585
###########################################################
@@ -160,4 +160,4 @@ token_list:
160160
- UH0
161161
- AW0
162162
- OY0
163-
- <sos/eos>
163+
- <sos/eos>

examples/vctk/ernie_sat/local/train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
88
--dev-metadata=dump/dev/norm/metadata.jsonl \
99
--config=${config_path} \
1010
--output-dir=${train_output_path} \
11-
--ngpu=1 \
11+
--ngpu=2 \
1212
--phones-dict=dump/phone_id_map.txt

paddlespeech/t2s/datasets/am_batch_fn.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,13 @@
2929

3030

3131
# 因为要传参数,所以需要额外构建
32-
def build_erniesat_collate_fn(
33-
mlm_prob: float=0.8,
34-
mean_phn_span: int=8,
35-
seg_emb: bool=False,
36-
text_masking: bool=False,
37-
epoch: int=-1, ):
38-
39-
if epoch == -1:
40-
mlm_prob_factor = 1
41-
else:
42-
mlm_prob_factor = 0.8
32+
def build_erniesat_collate_fn(mlm_prob: float=0.8,
33+
mean_phn_span: int=8,
34+
seg_emb: bool=False,
35+
text_masking: bool=False):
4336

4437
return ErnieSATCollateFn(
45-
mlm_prob=mlm_prob * mlm_prob_factor,
38+
mlm_prob=mlm_prob,
4639
mean_phn_span=mean_phn_span,
4740
seg_emb=seg_emb,
4841
text_masking=text_masking)

paddlespeech/t2s/exps/ernie_sat/synthesize.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ def evaluate(args):
7373
mlm_prob=erniesat_config.mlm_prob,
7474
mean_phn_span=erniesat_config.mean_phn_span,
7575
seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
76-
text_masking=False,
77-
epoch=-1)
76+
text_masking=False)
7877

7978
gen_raw = True
8079
erniesat_mu, erniesat_std = np.load(args.erniesat_stat)

paddlespeech/t2s/exps/ernie_sat/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,7 @@ def train_sp(args, config):
8484
mlm_prob=config.mlm_prob,
8585
mean_phn_span=config.mean_phn_span,
8686
seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
87-
text_masking=config["model"]["text_masking"],
88-
epoch=config["max_epoch"])
87+
text_masking=config["model"]["text_masking"])
8988

9089
train_sampler = DistributedBatchSampler(
9190
train_dataset,

0 commit comments

Comments
 (0)