Skip to content

Commit ba57348

Browse files
JiayiFengjetfuel
authored and committed
trainer.test() (#10453)
* a draft of trainer.test() * polish trainer.test() * polish trainer.test() * update code format * update * polish code * polish code * polish code * Make trainer.test follow the rule of returning [loss, metric, metric, ..]
1 parent f3ffec2 commit ba57348

File tree

4 files changed

+104
-27
lines changed

4 files changed

+104
-27
lines changed

python/paddle/fluid/framework.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(self,
160160
persistable=None,
161161
error_clip=None,
162162
stop_gradient=False,
163+
is_data=False,
163164
**kwargs):
164165
self.block = block
165166
self.error_clip = error_clip
@@ -238,6 +239,7 @@ def __init__(self,
238239
self.block.vars[name] = self
239240
self.op = None
240241
self.stop_gradient = stop_gradient
242+
self.is_data = is_data
241243

242244
def __str__(self):
243245
return self.to_string(True)
@@ -475,7 +477,7 @@ def find_name(var_list, name):
475477
if isinstance(attrs[attr_name], Block):
476478
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
477479
elif isinstance(attrs[attr_name], core.BlockDesc) or \
478-
isinstance(attrs[attr_name], core.ProgramDesc):
480+
isinstance(attrs[attr_name], core.ProgramDesc):
479481
self.desc.set_serialized_attr(
480482
attr_name, attrs[attr_name].serialize_to_string())
481483
else:
@@ -978,15 +980,17 @@ def clone_variable(self, var):
978980
shape=var.shape,
979981
dtype=var.dtype,
980982
type=var.type,
981-
persistable=True)
983+
persistable=True,
984+
is_data=var.is_data)
982985
else:
983986
ret_var = self.create_var(
984987
name=var.name,
985988
shape=var.shape,
986989
dtype=var.dtype,
987990
type=var.type,
988991
lod_level=var.lod_level,
989-
persistable=True)
992+
persistable=True,
993+
is_data=var.is_data)
990994
return ret_var
991995

992996

@@ -1051,6 +1055,7 @@ def clone(self, for_test=False):
10511055
p.sync_with_cpp()
10521056

10531057
p.copy_param_info_from(self)
1058+
p.copy_data_info_from(self)
10541059
return p
10551060

10561061
def prune(self, targets):
@@ -1172,6 +1177,26 @@ def copy_param_info_from(self, other):
11721177
"program, with represent the same topology")
11731178
self.global_block().copy_param_info_from(other.global_block())
11741179

1180+
def copy_data_info_from(self, other):
    """
    Copy the ``is_data`` flags of data variables from another program.

    Used after ``Program.clone()`` so that the cloned program knows which
    variables in its global block are data (feed) variables.

    Args:
        other(Program): The program to copy data information from. It must
            have the same block topology as this program.

    Raises:
        TypeError: If ``other`` is not a Program.
        ValueError: If the two programs have different numbers of blocks.

    Returns:
        None
    """
    if not isinstance(other, Program):
        # NOTE: the original messages said "copy_param_info_from" — a
        # copy-paste slip from the sibling method; name this method instead.
        raise TypeError("copy_data_info_from should be invoked with "
                        "Program")

    if len(self.blocks) != len(other.blocks):
        raise ValueError("copy_data_info_from should be invoked with two "
                         "programs which represent the same topology")
    # Mirror the flag onto the variable of the same name in our global block.
    for var in other.global_block().vars.itervalues():
        if var.is_data:
            self.global_block().var(var.name).is_data = True
1199+
11751200
def list_vars(self):
11761201
for each_block in self.blocks:
11771202
for each_var in each_block.vars.itervalues():

python/paddle/fluid/layers/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def data(name,
7878
dtype=dtype,
7979
type=type,
8080
stop_gradient=stop_gradient,
81-
lod_level=lod_level)
82-
data_var.is_data = True
81+
lod_level=lod_level,
82+
is_data=True)
8383
return data_var
8484

8585

python/paddle/fluid/tests/book/high-level-api/word2vec/no_test_word2vec_new_api.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,11 @@ def inference_program(is_sparse):
8080

8181

8282
def train_program(is_sparse):
    """Build the training network: the inference net plus a cross-entropy
    loss, returning the mean cost (the loss is the first/only output)."""
    # The declaration of 'next_word' must come after the call to
    # inference_program, otherwise the data-input order of the train
    # program would be [next_word, firstw, secondw, thirdw, forthw],
    # which is not correct.
    predict_word = inference_program(is_sparse)
    next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
    cost = fluid.layers.cross_entropy(input=predict_word, label=next_word)
    avg_cost = fluid.layers.mean(cost)
    return avg_cost
@@ -90,14 +93,17 @@ def train_program(is_sparse):
9093
def train(use_cuda, is_sparse, save_path):
9194
train_reader = paddle.batch(
9295
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
96+
test_reader = paddle.batch(
97+
paddle.dataset.imikolov.test(word_dict, N), BATCH_SIZE)
9398

9499
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
95100

96101
def event_handler(event):
97-
print type(event)
102+
# print type(event)
98103
if isinstance(event, fluid.EndEpochEvent):
99-
avg_cost = trainer.test(reader=paddle.dataset.imikolov.test(
100-
word_dict, N))
104+
outs = trainer.test(reader=test_reader)
105+
avg_cost = outs[0]
106+
print("loss= ", avg_cost)
101107

102108
if avg_cost < 5.0:
103109
trainer.save_params(save_path)

python/paddle/fluid/trainer.py

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,15 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
7575
self.train_program = framework.Program()
7676

7777
with framework.program_guard(self.train_program, self.startup_program):
78-
loss = program_func()
78+
program_func_outs = program_func()
79+
self.test_outputs = program_func_outs if isinstance(
80+
program_func_outs, list) else [program_func_outs]
81+
self.test_program = self.train_program.clone()
7982
if not isinstance(optimizer, opt_module.Optimizer):
8083
raise TypeError(
8184
"The optimizer should be an instance of Optimizer")
82-
85+
# The first element of program_func_outs is loss.
86+
loss = self.test_outputs[0]
8387
optimize_ops, params_grads = optimizer.minimize(loss)
8488

8589
self.place = Trainer._check_and_get_place(place)
@@ -168,8 +172,17 @@ def train(self,
168172

169173
self._train_by_executor(num_epochs, event_handler, reader, feed_order)
170174

171-
def test(self, reader):
172-
pass
175+
def test(self, reader, feed_order=None):
    """
    Test the model on the given test data.

    Args:
        reader: The reader that yields test data.
        feed_order: Feeding order of the reader. None will follow the
            defining order in the program (variables flagged as data).

    Returns:
        list: ``[loss, metric, metric, ...]`` — the per-variable averages of
        ``self.test_outputs`` over all test batches.
    """

    return self._test_by_executor(reader, feed_order, self.test_outputs)
173186

174187
def save_params(self, param_path):
175188
# reference: save_persistables in io.py
@@ -225,26 +238,59 @@ def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
225238
226239
"""
227240
with self._prog_and_scope_guard():
228-
exe = executor.Executor(self.place)
229-
if feed_order is None:
230-
feed_var_list = [
231-
var
232-
for var in self.train_program.global_block(
233-
).vars.itervalues()
234-
if hasattr(var, 'is_data') and var.is_data
235-
]
236-
else:
237-
feed_var_list = [
238-
self.train_program.global_block().var(var_name)
239-
for var_name in feed_order
240-
]
241-
241+
feed_var_list = build_feed_var_list(self.train_program, feed_order)
242242
feeder = data_feeder.DataFeeder(
243243
feed_list=feed_var_list, place=self.place)
244+
exe = executor.Executor(self.place)
244245
for epoch_id in range(num_epochs):
245246
event_handler(BeginEpochEvent(epoch_id))
246247
for step_id, data in enumerate(reader()):
247248
event_handler(BeginStepEvent(epoch_id, step_id))
248249
exe.run(feed=feeder.feed(data), fetch_list=[])
249250
event_handler(EndStepEvent(epoch_id, step_id))
250251
event_handler(EndEpochEvent(epoch_id))
252+
253+
def _test_by_executor(self, reader, feed_order, fetch_list):
    """
    Run ``self.test_program`` over every batch yielded by ``reader`` and
    average the fetched values.

    Args:
        reader: A batched reader yielding test data.
        feed_order: Feeding order (None, list, or dict) — see
            ``build_feed_var_list``.
        fetch_list: Variables to fetch per batch; element 0 is the loss.

    Returns:
        list: For each fetched variable, the mean of its first scalar over
        all batches.

    Raises:
        ValueError: If the reader yields no batches at all.
    """
    with executor.scope_guard(self.scope):
        feed_var_list = build_feed_var_list(self.test_program, feed_order)
        feeder = data_feeder.DataFeeder(
            feed_list=feed_var_list, place=self.place)
        exe = executor.Executor(self.place)
        accumulated = len(fetch_list) * [0]
        count = 0
        for data in reader():
            outs = exe.run(program=self.test_program,
                           feed=feeder.feed(data),
                           fetch_list=fetch_list)
            # Each fetched value is array-like; fold its first element
            # into the running per-variable sums.
            accumulated = [x[0] + x[1][0] for x in zip(accumulated, outs)]
            count += 1

        if count == 0:
            # Averaging below would divide by zero; make an empty test set
            # an explicit error instead of a cryptic ZeroDivisionError.
            raise ValueError("The test reader did not yield any data.")

        return [x / count for x in accumulated]
269+
270+
271+
def build_feed_var_list(program, feed_order):
    """
    Resolve the feed variables of ``program`` according to ``feed_order``.

    Args:
        program(Program): The program whose global block holds the variables.
        feed_order: One of:
            * None — use every variable flagged ``is_data``, in defining order;
            * list — variable names, fed in list order;
            * dict — variable name -> feed position; the positions must form
              a permutation of [0, len(feed_order)).

    Returns:
        list: The feed variables, in the requested order.
    """
    if not isinstance(program, framework.Program):
        raise TypeError("The 'program' should be an object of Program")

    global_block = program.global_block()

    if feed_order is None:
        # Default: every variable that was declared as a data input.
        return [
            var for var in global_block.vars.itervalues() if var.is_data
        ]

    if isinstance(feed_order, list):
        return [global_block.var(var_name) for var_name in feed_order]

    if not isinstance(feed_order, dict):
        raise TypeError(
            "The 'feed_order' should be either None, list or dict.")
    if not sorted(feed_order.values()) == range(len(feed_order)):
        raise ValueError(
            "The values of 'feed_order' should be a permutation of [0, len(feed_order))"
        )
    # Order the names by their requested feed position.
    name_position_pairs = sorted(feed_order.items(), key=lambda item: item[1])
    return [global_block.var(pair[0]) for pair in name_position_pairs]

0 commit comments

Comments
 (0)