Skip to content

Commit f04ae97

Browse files
authored
Merge pull request #12161 from JiayiFeng/make_get_test_program_private
Remove buggy get_test_program and refine reader demo
2 parents 7b63b85 + 0388d1c commit f04ae97

File tree

5 files changed

+140
-246
lines changed

5 files changed

+140
-246
lines changed

python/paddle/fluid/io.py

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -790,101 +790,3 @@ def get_parameter_value_by_name(name, executor, program=None):
790790
program = default_main_program()
791791
var = program.global_block().var(name)
792792
return get_parameter_value(var, executor)
793-
794-
795-
def get_test_program(filelist, program=None, startup_program=None):
    """
    Transpile current train program to a program to read test dataset
    if the program is using reader ops like "open_files_op".

    Args:
        filelist: value written to the "file_names" attribute of every op
            whose type equals the type of the first reader op found in the
            startup program.
        program: program whose READER variables are renamed with a "_test"
            suffix. Defaults to ``default_main_program()``.
        startup_program: program that receives a duplicate of each reader op,
            wired to the new "_test" reader variables. Defaults to
            ``default_startup_program()``.

    Returns:
        ``program`` itself (not a copy). When the startup program contains
        no reader op, it is returned without any modification.

    NOTE(review): both ``program`` and ``startup_program`` are modified in
    place, so the caller's train program is affected as well -- confirm this
    is acceptable for every caller.
    """

    def _copy_reader_var_(block, var, new_name=None):
        # Create a persistable READER variable in `block` that mirrors the
        # shapes and dtypes of `var`.
        if new_name is None:
            new_name = var.name
        new_var = block.create_var(
            name=str(new_name), type=core.VarDesc.VarType.READER)
        new_var.desc.set_shapes(var.desc.shapes())
        new_var.desc.set_dtypes(var.desc.dtypes())
        new_var.persistable = True
        return new_var

    def _get_test_reader_name(train_reader_name):
        return train_reader_name + "_test"

    def _is_reader_op(op):
        # An op counts as a reader op when the first variable of its "Out"
        # output is of type READER.
        block = op.block
        if "Out" in op.output_names:
            reader_out = block.vars[op.output("Out")[0]]
            if reader_out.type == core.VarDesc.VarType.READER:
                return True
        return False

    if program is None:
        program = default_main_program()
    if startup_program is None:
        startup_program = default_startup_program()
    startup_block = startup_program.global_block()

    # 1. find the reader ops that create the original (train) reader vars
    startup_reader_op_list = [
        op for op in startup_block.ops if _is_reader_op(op)
    ]

    if not startup_reader_op_list:
        return program

    root_reader_op = startup_reader_op_list[0]
    train_test_reader_map = {}
    # 2. add operators to the startup program that open and read the test
    #    data files
    for op in startup_reader_op_list:
        assert (len(op.output("Out")) == 1)
        train_reader_name = op.output("Out")[0]
        train_reader = startup_block.vars[train_reader_name]
        test_reader = _copy_reader_var_(
            startup_block,
            train_reader,
            new_name=_get_test_reader_name(train_reader_name))
        train_test_reader_map[train_reader.name] = test_reader

        test_op_inputs = {}
        for name in op.input_names:
            test_arg_vars = []
            for arg_name in op.input(name):
                if name == "UnderlyingReader":
                    # Chain the duplicated op to the *test* copy of the
                    # reader it decorates, not to the train reader.
                    arg_var = train_test_reader_map[arg_name]
                else:
                    arg_var = startup_block.vars[arg_name]
                test_arg_vars.append(arg_var)
            test_op_inputs[name] = test_arg_vars

        test_op = startup_block.append_op(
            type=op.type,
            inputs=test_op_inputs,
            outputs={'Out': [test_reader]},
            attrs=op.attrs)
        # root reader op's filelist attr for read test files
        if op.type == root_reader_op.type:
            test_op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            test_op.set_attr("pass_num", 1)

    # 3. rename reader vars in inference program to different name
    #    to avoid read from train data.
    main_block = program.global_block()
    # Iterate over a snapshot: renaming presumably updates `main_block.vars`
    # while we walk it -- TODO confirm against Block._rename_var.
    for var in list(main_block.vars.values()):
        if var.type == core.VarDesc.VarType.READER:
            main_block._rename_var(
                str(var.name), str(_get_test_reader_name(var.name)))

    for op in main_block.ops:
        # Update the op being visited. The previous code called
        # `test_op.set_attr`, which re-targeted the last op appended to the
        # startup program in step 2 and left the main-program ops untouched.
        if op.type == root_reader_op.type:
            op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            op.set_attr("pass_num", 1)

    startup_program._sync_with_cpp()
    program._sync_with_cpp()

    return program

python/paddle/fluid/tests/demo/text_classification/convert_data_to_recordio.py renamed to python/paddle/fluid/tests/demo/file_reader/convert_data_to_recordio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def load_vocab(filename):
3535
word_dict = paddle.dataset.imdb.word_dict()
3636
else:
3737
word_dict = load_vocab(sys.argv[1])
38-
word_dict["<unk>"] = len(word_dict)
38+
word_dict["<unk>"] = len(word_dict)
3939
print "Dict dim = ", len(word_dict)
4040

4141
# input text data
@@ -50,7 +50,7 @@ def load_vocab(filename):
5050
BATCH_SIZE = 128
5151
train_reader = paddle.batch(
5252
paddle.reader.shuffle(
53-
paddle.dataset.imdb.train(word_dict), buf_size=10000),
53+
paddle.dataset.imdb.train(word_dict), buf_size=25000),
5454
batch_size=BATCH_SIZE)
5555

5656
test_reader = paddle.batch(
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle.fluid as fluid
16+
import numpy
17+
import sys
18+
19+
# RecordIO files opened by build_program() for the train and test programs.
TRAIN_FILES = ['train.recordio']
TEST_FILES = ['test.recordio']

# First dimension of the embedding table created in build_program().
# NOTE(review): presumably has to equal the "Dict dim" printed by
# convert_data_to_recordio.py -- confirm before changing the dataset.
DICT_DIM = 5147

# Second dimension of the embedding table (width of one word vector).
emb_dim = 128

# Number of filters of each sequence_conv_pool branch.
hid_dim = 128

# Output size of the final softmax layer.
class_dim = 2

# Number of train-then-test rounds executed by main().
epoch_num = 10
35+
36+
37+
def build_program(is_train):
    """Build the text-classification network inside the current program guard.

    Args:
        is_train: when true the network reads TRAIN_FILES and an Adagrad
            optimizer is attached to the mean loss; otherwise it reads
            TEST_FILES and no optimizer ops are added.

    Returns:
        A dict with three entries:
            'loss': the mean cross-entropy variable,
            'log':  ``[mean loss, accuracy]`` variables to fetch,
            'file': the reader returned by ``open_files`` (before double
                    buffering).
    """
    if is_train:
        recordio_files = TRAIN_FILES
    else:
        recordio_files = TEST_FILES

    # The reader ops are created outside the unique_name guard below.
    file_obj_handle = fluid.layers.io.open_files(
        filenames=recordio_files,
        shapes=[[-1, 1], [-1, 1]],
        lod_levels=[1, 0],
        dtypes=['int64', 'int64'])
    buffered_reader = fluid.layers.io.double_buffer(file_obj_handle)

    # NOTE(review): the fresh unique-name scope presumably makes the train
    # and test builds generate identical variable names so the two programs
    # can share parameters -- confirm.
    with fluid.unique_name.guard():
        words, label = fluid.layers.read_file(buffered_reader)

        embedded = fluid.layers.embedding(
            input=words, size=[DICT_DIM, emb_dim])

        # Two convolution + pooling branches over the same embedding, with
        # filter sizes 3 and 4, created in that order.
        conv_branches = [
            fluid.nets.sequence_conv_pool(
                input=embedded,
                num_filters=hid_dim,
                filter_size=window,
                act="tanh",
                pool_type="sqrt") for window in (3, 4)
        ]

        prediction = fluid.layers.fc(input=conv_branches,
                                     size=class_dim,
                                     act="softmax")

        # Per-sample cross entropy, then its mean and the accuracy.
        cost = fluid.layers.cross_entropy(input=prediction, label=label)
        avg_cost = fluid.layers.mean(x=cost)
        acc = fluid.layers.accuracy(input=prediction, label=label)

        if is_train:
            # Adagrad optimizer; only the train program gets update ops.
            optimizer = fluid.optimizer.Adagrad(learning_rate=0.001)
            optimizer.minimize(avg_cost)

    outputs = {}
    outputs['loss'] = avg_cost
    outputs['log'] = [avg_cost, acc]
    outputs['file'] = file_obj_handle
    return outputs
83+
84+
85+
def main():
    """Train the demo network on the recordio files and test after each epoch.

    Builds a train and a test program that share one startup program, runs
    the startup program once, then for ``epoch_num`` epochs:
      * runs the train program until the file reader raises
        ``fluid.core.EOFException``, printing loss and accuracy per batch;
      * runs the test program until the reader is exhausted and prints the
        mean loss and accuracy over all test batches.
    Each reader is reset after it has been exhausted so it can be read again
    in the next epoch.
    """
    train = fluid.Program()
    startup = fluid.Program()
    test = fluid.Program()

    with fluid.program_guard(train, startup):
        train_args = build_program(is_train=True)

    with fluid.program_guard(test, startup):
        test_args = build_program(is_train=False)

    use_cuda = fluid.core.is_compiled_with_cuda()
    # startup
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place=place)
    exe.run(startup)

    train_exe = fluid.ParallelExecutor(
        use_cuda=use_cuda,
        loss_name=train_args['loss'].name,
        main_program=train)
    test_exe = fluid.ParallelExecutor(
        use_cuda=use_cuda, main_program=test, share_vars_from=train_exe)

    # The names of the train variables are used for the test run as well.
    # NOTE(review): this relies on both programs producing identical
    # variable names (see unique_name.guard in build_program) -- confirm.
    fetch_var_list = [var.name for var in train_args['log']]
    for epoch_id in range(epoch_num):
        # train
        try:
            batch_id = 0
            while True:
                loss, acc = map(numpy.array,
                                train_exe.run(fetch_list=fetch_var_list))
                # Single-argument print with %s formatting gives the same
                # output as the Python 2 print statement and also runs on
                # Python 3.
                print('Train epoch %s batch %s loss: %s acc: %s' %
                      (epoch_id, batch_id, loss, acc))
                batch_id += 1
        except fluid.core.EOFException:
            print('End of epoch %s' % epoch_id)
            train_args['file'].reset()

        # test
        loss = []
        acc = []
        try:
            while True:
                loss_np, acc_np = map(numpy.array,
                                      test_exe.run(fetch_list=fetch_var_list))
                loss.append(loss_np[0])
                acc.append(acc_np[0])
        except fluid.core.EOFException:
            # Only end-of-data terminates the test pass; any other error
            # (or Ctrl-C) now propagates instead of being reported as a
            # finished test run.
            test_args['file'].reset()
            print('Test loss: %s acc: %s' %
                  (numpy.mean(loss), numpy.mean(acc)))


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)