File tree Expand file tree Collapse file tree 7 files changed +12
-21
lines changed
aishell3_vctk/ernie_sat/local Expand file tree Collapse file tree 7 files changed +12
-21
lines changed Original file line number Diff line number Diff line change @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
88 --dev-metadata=dump/dev/norm/metadata.jsonl \
99 --config=${config_path} \
1010 --output-dir=${train_output_path} \
11- --ngpu=1 \
11+ --ngpu=2 \
1212 --phones-dict=dump/phone_id_map.txt
Original file line number Diff line number Diff line change @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
88 --dev-metadata=dump/dev/norm/metadata.jsonl \
99 --config=${config_path} \
1010 --output-dir=${train_output_path} \
11- --ngpu=1 \
11+ --ngpu=2 \
1212 --phones-dict=dump/phone_id_map.txt
Original file line number Diff line number Diff line change @@ -79,7 +79,7 @@ grad_clip: 1.0
7979# ##########################################################
8080# TRAINING SETTING #
8181# ##########################################################
82- max_epoch : 200
82+ max_epoch : 600
8383num_snapshots : 5
8484
8585# ##########################################################
@@ -160,4 +160,4 @@ token_list:
160160- UH0
161161- AW0
162162- OY0
163- - <sos/eos>
163+ - <sos/eos>
Original file line number Diff line number Diff line change @@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
88 --dev-metadata=dump/dev/norm/metadata.jsonl \
99 --config=${config_path} \
1010 --output-dir=${train_output_path} \
11- --ngpu=1 \
11+ --ngpu=2 \
1212 --phones-dict=dump/phone_id_map.txt
Original file line number Diff line number Diff line change 2929
3030
3131# 因为要传参数,所以需要额外构建
32- def build_erniesat_collate_fn (
33- mlm_prob : float = 0.8 ,
34- mean_phn_span : int = 8 ,
35- seg_emb : bool = False ,
36- text_masking : bool = False ,
37- epoch : int = - 1 , ):
38-
39- if epoch == - 1 :
40- mlm_prob_factor = 1
41- else :
42- mlm_prob_factor = 0.8
32+ def build_erniesat_collate_fn (mlm_prob : float = 0.8 ,
33+ mean_phn_span : int = 8 ,
34+ seg_emb : bool = False ,
35+ text_masking : bool = False ):
4336
4437 return ErnieSATCollateFn (
45- mlm_prob = mlm_prob * mlm_prob_factor ,
38+ mlm_prob = mlm_prob ,
4639 mean_phn_span = mean_phn_span ,
4740 seg_emb = seg_emb ,
4841 text_masking = text_masking )
Original file line number Diff line number Diff line change @@ -73,8 +73,7 @@ def evaluate(args):
7373 mlm_prob = erniesat_config .mlm_prob ,
7474 mean_phn_span = erniesat_config .mean_phn_span ,
7575 seg_emb = erniesat_config .model ['enc_input_layer' ] == 'sega_mlm' ,
76- text_masking = False ,
77- epoch = - 1 )
76+ text_masking = False )
7877
7978 gen_raw = True
8079 erniesat_mu , erniesat_std = np .load (args .erniesat_stat )
Original file line number Diff line number Diff line change @@ -84,8 +84,7 @@ def train_sp(args, config):
8484 mlm_prob = config .mlm_prob ,
8585 mean_phn_span = config .mean_phn_span ,
8686 seg_emb = config .model ['enc_input_layer' ] == 'sega_mlm' ,
87- text_masking = config ["model" ]["text_masking" ],
88- epoch = config ["max_epoch" ])
87+ text_masking = config ["model" ]["text_masking" ])
8988
9089 train_sampler = DistributedBatchSampler (
9190 train_dataset ,
You can’t perform that action at this time.
0 commit comments