2424
2525from utils .static_ps .reader_helper import get_reader
2626from utils .utils_single import load_yaml , load_static_model_class , get_abs_model , create_data_loader , reset_auc
27- from utils .save_load import save_static_model
27+ from utils .save_load import save_static_model , save_inference_model
2828
2929import time
3030import argparse
3636def parse_args ():
3737 parser = argparse .ArgumentParser ("PaddleRec train static script" )
3838 parser .add_argument ("-m" , "--config_yaml" , type = str )
39+ parser .add_argument ("-o" , "--opt" , nargs = '*' , type = str )
3940 args = parser .parse_args ()
4041 args .abs_dir = os .path .dirname (os .path .abspath (args .config_yaml ))
4142 args .config_yaml = get_abs_model (args .config_yaml )
@@ -49,6 +50,12 @@ def main(args):
4950 config = load_yaml (args .config_yaml )
5051 config ["yaml_path" ] = args .config_yaml
5152 config ["config_abs_dir" ] = args .abs_dir
53+ # override config entries with the key=value pairs passed via -o/--opt
54+ if args .opt :
55+ for parameter in args .opt :
56+ parameter = parameter .strip ()
57+ key , value = parameter .split ("=" )
58+ config [key ] = value
5259 # load static model class
5360 static_model_class = load_static_model_class (config )
5461
@@ -63,6 +70,7 @@ def main(args):
6370 use_gpu = config .get ("runner.use_gpu" , True )
6471 use_auc = config .get ("runner.use_auc" , False )
6572 use_visual = config .get ("runner.use_visual" , False )
73+ use_inference = config .get ("runner.use_inference" , False )
6674 auc_num = config .get ("runner.auc_num" , 1 )
6775 train_data_dir = config .get ("runner.train_data_dir" , None )
6876 epochs = config .get ("runner.epochs" , None )
@@ -74,9 +82,9 @@ def main(args):
7482 os .environ ["CPU_NUM" ] = str (config .get ("runner.thread_num" , 1 ))
7583 logger .info ("**************common.configs**********" )
7684 logger .info (
77- "use_gpu: {}, use_visual: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}" .
78- format (use_gpu , use_visual , train_data_dir , epochs , print_interval ,
79- model_save_path ))
85+ "use_gpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}" .
86+ format (use_gpu , use_visual , batch_size , train_data_dir , epochs ,
87+ print_interval , model_save_path ))
8088 logger .info ("**************common.configs**********" )
8189
8290 place = paddle .set_device ('gpu' if use_gpu else 'cpu' )
@@ -124,11 +132,44 @@ def main(args):
124132 else :
125133 logger .info ("reader type wrong" )
126134
127- save_static_model (
128- paddle .static .default_main_program (),
129- model_save_path ,
130- epoch_id ,
131- prefix = 'rec_static' )
135+ if use_inference :
136+ feed_var_names = config .get ("runner.save_inference_feed_varnames" ,
137+ [])
138+ feedvars = []
139+ fetch_var_names = config .get (
140+ "runner.save_inference_fetch_varnames" , [])
141+ fetchvars = []
142+ for var_name in feed_var_names :
143+ if var_name not in paddle .static .default_main_program (
144+ ).global_block ().vars :
145+ raise ValueError (
146+ "Feed variable: {} not in default_main_program, global block has follow vars: {}" .
147+ format (var_name ,
148+ paddle .static .default_main_program ()
149+ .global_block ().vars .keys ()))
150+ else :
151+ feedvars .append (paddle .static .default_main_program ()
152+ .global_block ().vars [var_name ])
153+ for var_name in fetch_var_names :
154+ if var_name not in paddle .static .default_main_program (
155+ ).global_block ().vars :
156+ raise ValueError (
157+ "Fetch variable: {} not in default_main_program, global block has follow vars: {}" .
158+ format (var_name ,
159+ paddle .static .default_main_program ()
160+ .global_block ().vars .keys ()))
161+ else :
162+ fetchvars .append (paddle .static .default_main_program ()
163+ .global_block ().vars [var_name ])
164+
165+ save_inference_model (model_save_path , epoch_id , feedvars ,
166+ fetchvars , exe )
167+ else :
168+ save_static_model (
169+ paddle .static .default_main_program (),
170+ model_save_path ,
171+ epoch_id ,
172+ prefix = 'rec_static' )
132173
133174
134175def dataset_train (epoch_id , dataset , fetch_vars , exe , config ):
@@ -179,7 +220,7 @@ def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
179220 logger .info (
180221 "epoch: {}, batch_id: {}, " .format (epoch_id ,
181222 batch_id ) + metric_str +
182- "avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec " .
223+ "avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} ins/s " .
183224 format (train_reader_cost / print_interval , (
184225 train_reader_cost + train_run_cost ) / print_interval ,
185226 total_samples / print_interval , total_samples / (
0 commit comments