11#!/usr/bin/env python3
22# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3-
43'''
54 Merge training configs into a single inference config.
65 The single inference config is for CLI, which only takes a single config to do inferencing.
76 The trainig configs includes: model config, preprocess config, decode config, vocab file and cmvn file.
87'''
9-
10- import yaml
11- import json
12- import os
138import argparse
9+ import json
1410import math
11+ import os
12+ from contextlib import redirect_stdout
13+
1514from yacs .config import CfgNode
1615
1716from paddlespeech .s2t .frontend .utility import load_dict
18- from contextlib import redirect_stdout
1917
2018
2119def save (save_path , config ):
@@ -29,18 +27,21 @@ def load(save_path):
2927 config .merge_from_file (save_path )
3028 return config
3129
30+
3231def load_json (json_path ):
3332 with open (json_path ) as f :
3433 json_content = json .load (f )
3534 return json_content
3635
36+
3737def remove_config_part (config , key_list ):
3838 if len (key_list ) == 0 :
3939 return
40- for i in range (len (key_list ) - 1 ):
40+ for i in range (len (key_list ) - 1 ):
4141 config = config [key_list [i ]]
4242 config .pop (key_list [- 1 ])
4343
44+
4445def load_cmvn_from_json (cmvn_stats ):
4546 means = cmvn_stats ['mean_stat' ]
4647 variance = cmvn_stats ['var_stat' ]
@@ -51,17 +52,17 @@ def load_cmvn_from_json(cmvn_stats):
5152 if variance [i ] < 1.0e-20 :
5253 variance [i ] = 1.0e-20
5354 variance [i ] = 1.0 / math .sqrt (variance [i ])
54- cmvn_stats = {"mean" :means , "istd" :variance }
55+ cmvn_stats = {"mean" : means , "istd" : variance }
5556 return cmvn_stats
5657
58+
5759def merge_configs (
58- conf_path = "conf/conformer.yaml" ,
59- preprocess_path = "conf/preprocess.yaml" ,
60- decode_path = "conf/tuning/decode.yaml" ,
61- vocab_path = "data/vocab.txt" ,
62- cmvn_path = "data/mean_std.json" ,
63- save_path = "conf/conformer_infer.yaml" ,
64- ):
60+ conf_path = "conf/conformer.yaml" ,
61+ preprocess_path = "conf/preprocess.yaml" ,
62+ decode_path = "conf/tuning/decode.yaml" ,
63+ vocab_path = "data/vocab.txt" ,
64+ cmvn_path = "data/mean_std.json" ,
65+ save_path = "conf/conformer_infer.yaml" , ):
6566
6667 # Load the configs
6768 config = load (conf_path )
@@ -72,17 +73,16 @@ def merge_configs(
7273 if cmvn_path .split ("." )[- 1 ] == 'json' :
7374 cmvn_stats = load_json (cmvn_path )
7475 if os .path .exists (preprocess_path ):
75- preprocess_config = load (preprocess_path )
76+ preprocess_config = load (preprocess_path )
7677 for idx , process in enumerate (preprocess_config ["process" ]):
7778 if process ['type' ] == "cmvn_json" :
78- preprocess_config ["process" ][idx ][
79- "cmvn_path" ] = cmvn_stats
79+ preprocess_config ["process" ][idx ]["cmvn_path" ] = cmvn_stats
8080 break
8181
8282 config .preprocess_config = preprocess_config
8383 else :
8484 cmvn_stats = load_cmvn_from_json (cmvn_stats )
85- config .mean_std_filepath = [{"cmvn_stats" :cmvn_stats }]
85+ config .mean_std_filepath = [{"cmvn_stats" : cmvn_stats }]
8686 config .augmentation_config = ''
8787 # the cmvn file is end with .ark
8888 else :
@@ -95,7 +95,8 @@ def merge_configs(
9595 # Remove some parts of the config
9696
9797 if os .path .exists (preprocess_path ):
98- remove_train_list = ["train_manifest" ,
98+ remove_train_list = [
99+ "train_manifest" ,
99100 "dev_manifest" ,
100101 "test_manifest" ,
101102 "n_epoch" ,
@@ -124,9 +125,10 @@ def merge_configs(
124125 "batch_size" ,
125126 "maxlen_in" ,
126127 "maxlen_out" ,
127- ]
128+ ]
128129 else :
129- remove_train_list = ["train_manifest" ,
130+ remove_train_list = [
131+ "train_manifest" ,
130132 "dev_manifest" ,
131133 "test_manifest" ,
132134 "n_epoch" ,
@@ -141,43 +143,41 @@ def merge_configs(
141143 "weight_decay" ,
142144 "sortagrad" ,
143145 "num_workers" ,
144- ]
146+ ]
145147
146148 for item in remove_train_list :
147149 try :
148150 remove_config_part (config , [item ])
149151 except :
150- print ( item + " " + "can not be removed" )
152+ print ( item + " " + "can not be removed" )
151153
152154 # Save the config
153155 save (save_path , config )
154156
155157
156-
157158if __name__ == "__main__" :
158- parser = argparse .ArgumentParser (
159- prog = 'Config merge' , add_help = True )
159+ parser = argparse .ArgumentParser (prog = 'Config merge' , add_help = True )
160160 parser .add_argument (
161- '--cfg_pth' , type = str , default = 'conf/transformer.yaml' , help = 'origin config file' )
161+ '--cfg_pth' ,
162+ type = str ,
163+ default = 'conf/transformer.yaml' ,
164+ help = 'origin config file' )
162165 parser .add_argument (
163- '--pre_pth' , type = str , default = "conf/preprocess.yaml" , help = '' )
166+ '--pre_pth' , type = str , default = "conf/preprocess.yaml" , help = '' )
164167 parser .add_argument (
165- '--dcd_pth' , type = str , default = "conf/tuninig/decode.yaml" , help = '' )
168+ '--dcd_pth' , type = str , default = "conf/tuninig/decode.yaml" , help = '' )
166169 parser .add_argument (
167- '--vb_pth' , type = str , default = "data/lang_char/vocab.txt" , help = '' )
170+ '--vb_pth' , type = str , default = "data/lang_char/vocab.txt" , help = '' )
168171 parser .add_argument (
169- '--cmvn_pth' , type = str , default = "data/mean_std.json" , help = '' )
172+ '--cmvn_pth' , type = str , default = "data/mean_std.json" , help = '' )
170173 parser .add_argument (
171- '--save_pth' , type = str , default = "conf/transformer_infer.yaml" , help = '' )
174+ '--save_pth' , type = str , default = "conf/transformer_infer.yaml" , help = '' )
172175 parser_args = parser .parse_args ()
173176
174177 merge_configs (
175- conf_path = parser_args .cfg_pth ,
176- decode_path = parser_args .dcd_pth ,
177- preprocess_path = parser_args .pre_pth ,
178- vocab_path = parser_args .vb_pth ,
179- cmvn_path = parser_args .cmvn_pth ,
180- save_path = parser_args .save_pth ,
181- )
182-
183-
178+ conf_path = parser_args .cfg_pth ,
179+ decode_path = parser_args .dcd_pth ,
180+ preprocess_path = parser_args .pre_pth ,
181+ vocab_path = parser_args .vb_pth ,
182+ cmvn_path = parser_args .cmvn_pth ,
183+ save_path = parser_args .save_pth , )
0 commit comments