Skip to content

Commit 2bb0ac9

Browse files
committed
Polish code
1 parent a04d998 commit 2bb0ac9

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ def event_handler(event):
107107
event_handler=event_handler,
108108
feed_order=['pixel', 'label'])
109109

110-
if six.PY3:
111-
del trainer
110+
return trainer
112111

113112

114113
def infer(use_cuda, inference_program, parallel, params_dirname=None):
@@ -132,12 +131,15 @@ def main(use_cuda, parallel):
132131
save_path = "image_classification_vgg.inference.model"
133132

134133
os.environ['CPU_NUM'] = str(4)
135-
train(
134+
trainer = train(
136135
use_cuda=use_cuda,
137136
train_program=train_network,
138137
params_dirname=save_path,
139138
parallel=parallel)
140139

140+
if six.PY3:
141+
del trainer
142+
141143
# FIXME(zcd): in the inference stage, the number of
142144
# input data is one, it is not appropriate to use parallel.
143145
if parallel and use_cuda:

python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,7 @@ def event_handler(event):
9090
reader=train_reader,
9191
feed_order=['img', 'label'])
9292

93-
if six.PY3:
94-
del trainer
93+
return trainer
9594

9695

9796
def infer(use_cuda, inference_program, parallel, params_dirname=None):
@@ -117,12 +116,15 @@ def main(use_cuda, parallel):
117116

118117
# call train() with is_local argument to run distributed train
119118
os.environ['CPU_NUM'] = str(4)
120-
train(
119+
trainer = train(
121120
use_cuda=use_cuda,
122121
train_program=train_program,
123122
params_dirname=params_dirname,
124123
parallel=parallel)
125124

125+
if six.PY3:
126+
del trainer
127+
126128
# FIXME(zcd): in the inference stage, the number of
127129
# input data is one, it is not appropriate to use parallel.
128130
if parallel and use_cuda:

0 commit comments

Comments
 (0)