1919import warnings
2020import numpy as np
2121import random
22+ import json
2223import logging
2324import paddle .fluid as fluid
2425
@@ -147,17 +148,22 @@ def _executor_dataloader_train(self, model_dict, context):
147148 metrics_format = []
148149
149150 if context ["is_infer" ]:
150- metrics_format .append ("\t [Infer]\t {}: {{}}" .format ("batch" ))
151+ metrics_format .append ("\t [Infer] {}: {{}}" .format ("batch" ))
151152 else :
152- metrics_format .append ("\t [Train]\t {}: {{}}" .format ("batch" ))
153+ metrics_format .append ("\t [Train]" )
154+ if "current_epoch" in context :
155+ metrics_format .append (" epoch: {}" .format (context [
156+ "current_epoch" ]))
157+ metrics_format .append (" {}: {{}}" .format ("batch" ))
153158
154159 metrics_format .append ("{}: {{:.2f}}s" .format ("time_each_interval" ))
155160
156161 metrics_names = ["total_batch" ]
157-
162+ metrics_indexes = dict ()
158163 for name , var in metrics .items ():
159164 metrics_names .append (name )
160165 metrics_varnames .append (var .name )
166+ metrics_indexes [var .name ] = len (metrics_varnames ) - 1
161167 metrics_format .append ("{}: {{}}" .format (name ))
162168 metrics_format = ", " .join (metrics_format )
163169
@@ -166,6 +172,7 @@ def _executor_dataloader_train(self, model_dict, context):
166172 batch_id = 0
167173 begin_time = time .time ()
168174 scope = context ["model" ][model_name ]["scope" ]
175+ runner_results = []
169176 result = None
170177 with fluid .scope_guard (scope ):
171178 try :
@@ -182,18 +189,35 @@ def _executor_dataloader_train(self, model_dict, context):
182189 ]
183190 metrics .extend (metrics_rets )
184191
192+ batch_runner_result = {}
193+ for k , v in metrics_indexes .items ():
194+ batch_runner_result [k ] = np .array (metrics_rets [
195+ v ]).tolist ()
196+ runner_results .append (batch_runner_result )
197+
185198 if batch_id % fetch_period == 0 and batch_id != 0 :
186199 end_time = time .time ()
187200 seconds = end_time - begin_time
188201 metrics_logging = metrics [:]
189202 metrics_logging = metrics .insert (1 , seconds )
190203 begin_time = end_time
191-
192204 logging .info (metrics_format .format (* metrics ))
193205 batch_id += 1
194206 except fluid .core .EOFException :
195207 reader .reset ()
196208
209+ runner_result_save_path = envs .get_global_env (
210+ "runner." + context ["runner_name" ] + ".runner_result_dump_path" ,
211+ None )
212+ if runner_result_save_path :
213+ if "current_epoch" in context :
214+ runner_result_save_path = runner_result_save_path + "_epoch_{}" .format (
215+ context ["current_epoch" ])
216+ logging .info ("Dump runner result in {}" .format (
217+ runner_result_save_path ))
218+ with open (runner_result_save_path , 'w+' ) as fout :
219+ json .dump (runner_results , fout )
220+
197221 if batch_id > 0 :
198222 result = dict (zip (metrics_names , metrics ))
199223 return result
@@ -402,6 +426,7 @@ def run(self, context):
402426 filelist = context ["file_list" ]
403427 context ["file_list" ] = shuffle_files (need_shuffle_files ,
404428 filelist )
429+ context ["current_epoch" ] = epoch
405430 begin_time = time .time ()
406431 result = self ._run (context , model_dict )
407432 end_time = time .time ()
@@ -450,6 +475,7 @@ def run(self, context):
450475 filelist = context ["file_list" ]
451476 context ["file_list" ] = shuffle_files (need_shuffle_files ,
452477 filelist )
478+ context ["current_epoch" ] = epoch
453479 begin_time = time .time ()
454480 result = self ._run (context , model_dict )
455481 end_time = time .time ()
@@ -500,6 +526,7 @@ def run(self, context):
500526 filelist = context ["file_list" ]
501527 context ["file_list" ] = shuffle_files (need_shuffle_files ,
502528 filelist )
529+ context ["current_epoch" ] = epoch
503530 begin_time = time .time ()
504531 self ._run (context , model_dict )
505532 end_time = time .time ()
@@ -533,6 +560,7 @@ def run(self, context):
533560 filelist = context ["file_list" ]
534561 context ["file_list" ] = shuffle_files (need_shuffle_files ,
535562 filelist )
563+ context ["current_epoch" ] = epoch
536564 begin_time = time .time ()
537565 self ._run (context , model_dict )
538566 end_time = time .time ()
0 commit comments