@@ -62,6 +62,7 @@ def main(args):
6262
6363 use_gpu = config .get ("runner.use_gpu" , True )
6464 use_auc = config .get ("runner.use_auc" , False )
65+ use_visual = config .get ("runner.use_visual" , False )
6566 auc_num = config .get ("runner.auc_num" , 1 )
6667 train_data_dir = config .get ("runner.train_data_dir" , None )
6768 epochs = config .get ("runner.epochs" , None )
@@ -73,8 +74,8 @@ def main(args):
7374 os .environ ["CPU_NUM" ] = str (config .get ("runner.thread_num" , 1 ))
7475 logger .info ("**************common.configs**********" )
7576 logger .info (
76- "use_gpu: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}" .
77- format (use_gpu , train_data_dir , epochs , print_interval ,
77+ "use_gpu: {}, use_visual: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}" .
78+ format (use_gpu , use_visual , train_data_dir , epochs , print_interval ,
7879 model_save_path ))
7980 logger .info ("**************common.configs**********" )
8081
@@ -85,6 +86,14 @@ def main(args):
8586
8687 last_epoch_id = config .get ("last_epoch" , - 1 )
8788
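89+ # When use_visual is enabled, create a VisualDL LogWriter (log_visual) that writes its log data under <abs_dir>/visualDL_log/train; otherwise log_visual is None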
90+ if use_visual :
91+ from visualdl import LogWriter
92+ log_visual = LogWriter (args .abs_dir + "/visualDL_log/train" )
93+ else :
94+ log_visual = None
95+ step_num = 0
96+
8897 if reader_type == 'QueueDataset' :
8998 dataset , file_list = get_reader (input_data , config )
9099 elif reader_type == 'DataLoader' :
@@ -96,9 +105,9 @@ def main(args):
96105 if use_auc :
97106 reset_auc (auc_num )
98107 if reader_type == 'DataLoader' :
99- fetch_batch_var = dataloader_train (epoch_id , train_dataloader ,
100- input_data_names , fetch_vars ,
101- exe , config )
108+ fetch_batch_var , step_num = dataloader_train (
109+ epoch_id , train_dataloader , input_data_names , fetch_vars , exe ,
110+ config , use_visual , log_visual , step_num )
102111 metric_str = ""
103112 for var_idx , var_name in enumerate (fetch_vars ):
104113 metric_str += "{}: {}, " .format (var_name ,
@@ -139,7 +148,7 @@ def dataset_train(epoch_id, dataset, fetch_vars, exe, config):
139148
140149
141150def dataloader_train (epoch_id , train_dataloader , input_data_names , fetch_vars ,
142- exe , config ):
151+ exe , config , use_visual , log_visual , step_num ):
143152 print_interval = config .get ("runner.print_interval" , None )
144153 batch_size = config .get ("runner.train_batch_size" , None )
145154 interval_begin = time .time ()
@@ -162,6 +171,11 @@ def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
162171 for var_idx , var_name in enumerate (fetch_vars ):
163172 metric_str += "{}: {}, " .format (var_name ,
164173 fetch_batch_var [var_idx ])
174+ if use_visual :
175+ log_visual .add_scalar (
176+ tag = "train/" + var_name ,
177+ step = step_num ,
178+ value = fetch_batch_var [var_idx ])
165179 logger .info (
166180 "epoch: {}, batch_id: {}, " .format (epoch_id ,
167181 batch_id ) + metric_str +
@@ -174,7 +188,8 @@ def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
174188 train_run_cost = 0.0
175189 total_samples = 0
176190 reader_start = time .time ()
177- return fetch_batch_var
191+ step_num = step_num + 1
192+ return fetch_batch_var , step_num
178193
179194
180195if __name__ == "__main__" :
0 commit comments