Skip to content

Commit 8ef1f9f

Browse files
committed
Polish code
1 parent 07f495e commit 8ef1f9f

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ def event_handler(event):
107107
event_handler=event_handler,
108108
feed_order=['pixel', 'label'])
109109

110-
return trainer
110+
def _del_trainer(trainer):
111+
del trainer
112+
113+
if six.PY3:
114+
_del_trainer(trainer)
111115

112116

113117
def infer(use_cuda, inference_program, parallel, params_dirname=None):
@@ -131,15 +135,12 @@ def main(use_cuda, parallel):
131135
save_path = "image_classification_vgg.inference.model"
132136

133137
os.environ['CPU_NUM'] = str(4)
134-
trainer = train(
138+
train(
135139
use_cuda=use_cuda,
136140
train_program=train_network,
137141
params_dirname=save_path,
138142
parallel=parallel)
139143

140-
if six.PY3:
141-
del trainer
142-
143144
# FIXME(zcd): in the inference stage, the number of
144145
# input data is one, it is not appropriate to use parallel.
145146
if parallel and use_cuda:

python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,11 @@ def event_handler(event):
9090
reader=train_reader,
9191
feed_order=['img', 'label'])
9292

93-
return trainer
93+
def _del_trainer(trainer):
94+
del trainer
95+
96+
if six.PY3:
97+
_del_trainer(trainer)
9498

9599

96100
def infer(use_cuda, inference_program, parallel, params_dirname=None):
@@ -116,15 +120,12 @@ def main(use_cuda, parallel):
116120

117121
# call train() with is_local argument to run distributed train
118122
os.environ['CPU_NUM'] = str(4)
119-
trainer = train(
123+
train(
120124
use_cuda=use_cuda,
121125
train_program=train_program,
122126
params_dirname=params_dirname,
123127
parallel=parallel)
124128

125-
if six.PY3:
126-
del trainer
127-
128129
# FIXME(zcd): in the inference stage, the number of
129130
# input data is one, it is not appropriate to use parallel.
130131
if parallel and use_cuda:

0 commit comments

Comments
 (0)