@@ -437,13 +437,8 @@ def split_data(data, num_part):
     ]
 
 
-def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
+def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names,
                  sum_cost, token_num):
-    # Context to do validation.
-    test_program = train_progm.clone()
-    with fluid.program_guard(test_program):
-        test_program = fluid.io.get_inference_program([avg_cost])
-
     val_data = DataReader(
         src_vocab_fpath=TrainTaskConfig.src_vocab_fpath,
         trg_vocab_fpath=TrainTaskConfig.trg_vocab_fpath,
@@ -505,7 +500,7 @@ def test(exe=test_exe):
 
 
 def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
-               token_num, predict):
+               token_num, predict, test_program):
     # Initialize the parameters.
     if TrainTaskConfig.ckpt_path:
         lr_scheduler.current_steps = TrainTaskConfig.start_step
@@ -554,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
             -1] + label_data_input_fields
 
     if TrainTaskConfig.val_file_pattern is not None:
-        test = test_context(train_progm, avg_cost, train_exe, dev_count,
+        test = test_context(test_program, avg_cost, train_exe, dev_count,
                             data_input_names, sum_cost, token_num)
 
     # the best cross-entropy value with label smoothing
@@ -1647,6 +1642,8 @@ def get_model(is_dist, is_async):
     local_lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                                TrainTaskConfig.warmup_steps,
                                                TrainTaskConfig.learning_rate)
+    # Context to do validation.
+    test_program = fluid.default_main_program().clone(for_test=True)
 
     if not is_dist:
         optimizer = fluid.optimizer.Adam(
@@ -1671,7 +1668,7 @@ def get_model(is_dist, is_async):
             epsilon=TrainTaskConfig.eps)
         optimizer.minimize(sum_cost)
 
-    return sum_cost, avg_cost, predict, token_num, local_lr_scheduler
+    return sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program
 
 
 def update_args():
@@ -1705,7 +1702,7 @@ def run_pserver(self, args):
     def run_trainer(self, use_cuda, args):
         place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         TrainTaskConfig.use_gpu = use_cuda
-        sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
+        sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
             args.is_dist, not args.sync_mode)
 
         if args.is_dist:
@@ -1726,7 +1723,7 @@ def run_trainer(self, use_cuda, args):
         TrainTaskConfig.local = not args.is_dist
 
         train_loop(startup_exe, trainer_prog, 1, sum_cost, avg_cost,
-                   local_lr_scheduler, token_num, predict)
+                   local_lr_scheduler, token_num, predict, test_program)
 
 
 if __name__ == "__main__":
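
The substantive change is where the validation program gets built: instead of cloning `train_progm` inside `test_context` and running it through `fluid.io.get_inference_program`, `get_model` now clones the default main program with `clone(for_test=True)` before the optimizer is applied, and threads the resulting `test_program` through `train_loop` into `test_context`. A minimal sketch of that ordering, with a hypothetical toy network standing in for the transformer (not from this patch):

```python
import paddle.fluid as fluid

# Hypothetical forward pass, in place of the transformer model.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.mean(
    fluid.layers.square_error_cost(input=pred, label=y))

# Clone for evaluation *before* minimize(): the clone keeps only the
# forward ops and switches layers such as dropout to test behavior.
test_program = fluid.default_main_program().clone(for_test=True)

# Backward and optimizer ops are added to the training program only.
fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_cost)
```

Cloning after `minimize()` would drag the backward and optimizer ops into the test program, which is why the patch inserts the clone in `get_model` right after the learning-rate scheduler and ahead of both optimizer branches.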