@@ -437,11 +437,8 @@ def split_data(data, num_part):
437
437
]
438
438
439
439
440
- def test_context (train_progm , avg_cost , train_exe , dev_count , data_input_names ,
440
+ def test_context (test_program , avg_cost , train_exe , dev_count , data_input_names ,
441
441
sum_cost , token_num ):
442
- # Context to do validation.
443
- test_program = train_progm .clone (for_test = True )
444
-
445
442
val_data = DataReader (
446
443
src_vocab_fpath = TrainTaskConfig .src_vocab_fpath ,
447
444
trg_vocab_fpath = TrainTaskConfig .trg_vocab_fpath ,
@@ -503,7 +500,7 @@ def test(exe=test_exe):
503
500
504
501
505
502
def train_loop (exe , train_progm , dev_count , sum_cost , avg_cost , lr_scheduler ,
506
- token_num , predict ):
503
+ token_num , predict , test_program ):
507
504
# Initialize the parameters.
508
505
if TrainTaskConfig .ckpt_path :
509
506
lr_scheduler .current_steps = TrainTaskConfig .start_step
@@ -552,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
552
549
- 1 ] + label_data_input_fields
553
550
554
551
if TrainTaskConfig .val_file_pattern is not None :
555
- test = test_context (train_progm , avg_cost , train_exe , dev_count ,
552
+ test = test_context (test_program , avg_cost , train_exe , dev_count ,
556
553
data_input_names , sum_cost , token_num )
557
554
558
555
# the best cross-entropy value with label smoothing
@@ -1645,6 +1642,8 @@ def get_model(is_dist, is_async):
1645
1642
local_lr_scheduler = LearningRateScheduler (ModelHyperParams .d_model ,
1646
1643
TrainTaskConfig .warmup_steps ,
1647
1644
TrainTaskConfig .learning_rate )
1645
+ # Context to do validation.
1646
+ test_program = fluid .default_main_program ().clone (for_test = True )
1648
1647
1649
1648
if not is_dist :
1650
1649
optimizer = fluid .optimizer .Adam (
@@ -1669,7 +1668,7 @@ def get_model(is_dist, is_async):
1669
1668
epsilon = TrainTaskConfig .eps )
1670
1669
optimizer .minimize (sum_cost )
1671
1670
1672
- return sum_cost , avg_cost , predict , token_num , local_lr_scheduler
1671
+ return sum_cost , avg_cost , predict , token_num , local_lr_scheduler , test_program
1673
1672
1674
1673
1675
1674
def update_args ():
@@ -1703,7 +1702,7 @@ def run_pserver(self, args):
1703
1702
def run_trainer (self , use_cuda , args ):
1704
1703
place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
1705
1704
TrainTaskConfig .use_gpu = use_cuda
1706
- sum_cost , avg_cost , predict , token_num , local_lr_scheduler = get_model (
1705
+ sum_cost , avg_cost , predict , token_num , local_lr_scheduler , test_program = get_model (
1707
1706
args .is_dist , not args .sync_mode )
1708
1707
1709
1708
if args .is_dist :
@@ -1724,7 +1723,7 @@ def run_trainer(self, use_cuda, args):
1724
1723
TrainTaskConfig .local = not args .is_dist
1725
1724
1726
1725
train_loop (startup_exe , trainer_prog , 1 , sum_cost , avg_cost ,
1727
- local_lr_scheduler , token_num , predict )
1726
+ local_lr_scheduler , token_num , predict , test_program )
1728
1727
1729
1728
1730
1729
if __name__ == "__main__" :
0 commit comments