Skip to content

Commit 44ba694

Browse files
committed
put clone(for_test=True) before optimization phase
1 parent abf019f commit 44ba694

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

python/paddle/fluid/tests/unittests/dist_transformer.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -437,11 +437,8 @@ def split_data(data, num_part):
437437
]
438438

439439

440-
def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
440+
def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names,
441441
sum_cost, token_num):
442-
# Context to do validation.
443-
test_program = train_progm.clone(for_test=True)
444-
445442
val_data = DataReader(
446443
src_vocab_fpath=TrainTaskConfig.src_vocab_fpath,
447444
trg_vocab_fpath=TrainTaskConfig.trg_vocab_fpath,
@@ -503,7 +500,7 @@ def test(exe=test_exe):
503500

504501

505502
def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
506-
token_num, predict):
503+
token_num, predict, test_program):
507504
# Initialize the parameters.
508505
if TrainTaskConfig.ckpt_path:
509506
lr_scheduler.current_steps = TrainTaskConfig.start_step
@@ -552,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
552549
-1] + label_data_input_fields
553550

554551
if TrainTaskConfig.val_file_pattern is not None:
555-
test = test_context(train_progm, avg_cost, train_exe, dev_count,
552+
test = test_context(test_program, avg_cost, train_exe, dev_count,
556553
data_input_names, sum_cost, token_num)
557554

558555
# the best cross-entropy value with label smoothing
@@ -1645,6 +1642,8 @@ def get_model(is_dist, is_async):
16451642
local_lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
16461643
TrainTaskConfig.warmup_steps,
16471644
TrainTaskConfig.learning_rate)
1645+
# Context to do validation.
1646+
test_program = fluid.default_main_program().clone(for_test=True)
16481647

16491648
if not is_dist:
16501649
optimizer = fluid.optimizer.Adam(
@@ -1669,7 +1668,7 @@ def get_model(is_dist, is_async):
16691668
epsilon=TrainTaskConfig.eps)
16701669
optimizer.minimize(sum_cost)
16711670

1672-
return sum_cost, avg_cost, predict, token_num, local_lr_scheduler
1671+
return sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program
16731672

16741673

16751674
def update_args():
@@ -1703,7 +1702,7 @@ def run_pserver(self, args):
17031702
def run_trainer(self, use_cuda, args):
17041703
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
17051704
TrainTaskConfig.use_gpu = use_cuda
1706-
sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
1705+
sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
17071706
args.is_dist, not args.sync_mode)
17081707

17091708
if args.is_dist:
@@ -1724,7 +1723,7 @@ def run_trainer(self, use_cuda, args):
17241723
TrainTaskConfig.local = not args.is_dist
17251724

17261725
train_loop(startup_exe, trainer_prog, 1, sum_cost, avg_cost,
1727-
local_lr_scheduler, token_num, predict)
1726+
local_lr_scheduler, token_num, predict, test_program)
17281727

17291728

17301729
if __name__ == "__main__":

0 commit comments

Comments (0)