Skip to content

Commit 552b2e4

Browse files
committed
API Train for quick start
1 parent 5383317 commit 552b2e4

File tree

2 files changed

+245
-0
lines changed

2 files changed

+245
-0
lines changed

demo/quick_start/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ data/pred.txt
1313
dataprovider_copy_1.py
1414
train.log
1515
output
16+
*.w0
17+
*.wbias
18+
*.pkl

demo/quick_start/api_train_gm.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
import random
2+
import cPickle
3+
import os
4+
import paddle.v2 as paddle
5+
6+
7+
class FileReader(object):
    """Batched, in-memory reader for labelled sentence data.

    Each data line is ``<label>\t<sentence>``; sentences are converted to
    word-id lists through ``word_dict``.  The whole dataset is held in
    ``__pool__`` and served in batches via the (Python 2) iterator
    protocol (``__iter__``/``next``).

    :type word_dict: dict
    :type __pool__: list
    """

    def __init__(self, word_dict, filename, batch_size, should_shuffle=True):
        # word_dict may be a path to a dictionary file or an already
        # loaded word -> id mapping.
        if isinstance(word_dict, basestring):
            self.word_dict = FileReader.read_from_dict(word_dict)
        else:
            self.word_dict = word_dict
        self.__should_shuffle__ = should_shuffle
        self.__batch_size__ = batch_size

        self.__pool__ = self.load_all_data(filename)
        self.__idx__ = 0

    def load_all_data(self, filename):
        """Load the entire dataset into memory.

        A ``*.txt`` file is parsed line by line (and cached next to it as
        ``*.pkl`` for faster later runs); a ``*.pkl`` file is loaded
        directly.

        :raise ValueError: if *filename* has an unsupported extension.
            (Previously this case silently returned ``None``.)
        """

        def __mapper__(line):
            # "<label>\t<sentence>"; words missing from the dictionary
            # are silently dropped, matching the original filter(None).
            label, sentence = line.split('\t')
            label = int(label)
            word_ids = [
                self.word_dict[w] for w in sentence.split()
                if w in self.word_dict
            ]
            return word_ids, label

        if filename[-3:] == 'txt':
            with open(filename, 'r') as f:
                ret_val = map(__mapper__, f)
            # Cache the parsed dataset so later runs can skip text parsing.
            with open("%s.pkl" % filename[:-4], 'wb') as f:
                cPickle.dump(ret_val, f, cPickle.HIGHEST_PROTOCOL)
            return ret_val
        elif filename[-3:] == 'pkl':
            with open(filename, 'rb') as f:
                return cPickle.load(f)
        else:
            raise ValueError(
                "unsupported data file %s, expect *.txt or *.pkl" % filename)

    def __iter__(self):
        self.reset()
        return self

    def reset(self):
        """Rewind to the first batch, reshuffling the pool if configured."""
        if self.__should_shuffle__:
            random.shuffle(self.__pool__)
        self.__idx__ = 0

    def next(self):
        """Return the next batch: a list of (word_ids, label) pairs.

        The final batch may be shorter than ``__batch_size__``.
        :raise StopIteration: when the pool is exhausted.
        """
        if self.__idx__ < len(self.__pool__):
            end = min(self.__idx__ + self.__batch_size__, len(self.__pool__))
            start = self.__idx__
            self.__idx__ = end
            return self.__pool__[start:end]
        else:
            raise StopIteration()

    @staticmethod
    def read_from_dict(fn):
        """Read a word dictionary file into a word -> line-index dict.

        The first whitespace-separated token of each line is the word.
        The mapping is cached as ``fn + '.pkl'`` and reused when present.
        """
        if os.path.exists(fn + '.pkl'):
            with open(fn + '.pkl', 'rb') as f:
                return cPickle.load(f)
        else:
            ret_val = dict()
            with open(fn, 'r') as f:
                for i, line in enumerate(f):
                    w = line.split()[0]
                    ret_val[w] = i
            with open(fn + '.pkl', 'wb') as f:
                cPickle.dump(ret_val, f, cPickle.HIGHEST_PROTOCOL)
            return ret_val
75+
76+
77+
def optimizer_config():
    """Declare the optimization settings for this demo (RMSProp, lr=1e-4).

    NOTE(review): batch_size=1 here while FileReader emits batches of
    1024 -- presumably the reader drives the effective batch size in this
    low-level API flow; confirm against paddle's settings() semantics.
    """
    opts = {
        'batch_size': 1,
        'learning_rate': 1e-4,
        'learning_method': paddle.config.RMSPropOptimizer(),
    }
    paddle.config.settings(**opts)
82+
83+
84+
def bow_config(dict_size):
    """Return a network-config callable for a 2-class bag-of-words model.

    :param dict_size: vocabulary size, i.e. width of the sentence input
        layer (a sparse binary vector over word ids).
    """

    def __network__():
        # Layer creation order is kept identical to the original config;
        # it may determine the proto's input_layer_names ordering, which
        # main() compares against the data order.
        words = paddle.config.data_layer(name='sentence', size=dict_size)
        prediction = paddle.config.fc_layer(
            input=words,
            size=2,
            act=paddle.config.SoftmaxActivation(),
            param_attr=paddle.config.ParamAttr(sparse_update=True))
        label = paddle.config.data_layer(name='label', size=2)
        cost = paddle.config.classification_cost(
            input=prediction, label=label)
        paddle.config.outputs(cost)

    return __network__
99+
100+
101+
def swap_batch(batch):
    """Lazily yield every two-element item of *batch* with its fields
    swapped, e.g. (word_ids, label) -> (label, word_ids)."""
    return ((second, first) for first, second in batch)
105+
106+
107+
def main():
    """Train and test the bag-of-words classifier through paddle's
    low-level GradientMachine / ParameterUpdater API.

    Flow: load data (preferring cached *.pkl), parse optimizer and
    network protos, build a local gradient machine and updater, then run
    10 train passes, each followed by a test pass and a parameter dump.
    """
    print 'Loading data into memory'
    # Prefer the pickle cache produced by FileReader.load_all_data on a
    # previous run; fall back to the raw text files.
    train_file_name = './data/train.pkl' if os.path.exists(
        './data/train.pkl') else './data/train.txt'

    test_file_name = './data/test.pkl' if os.path.exists(
        './data/test.pkl') else './data/test.txt'

    train_reader = FileReader(
        "./data/dict.txt", filename=train_file_name, batch_size=1024)
    # Reuse the already-loaded dictionary for the test reader.
    test_reader = FileReader(
        train_reader.word_dict, filename=test_file_name, batch_size=1024)

    print 'Done.'

    # CPU-only, 3 trainer threads.
    paddle.raw.initPaddle('--use_gpu=0', '--trainer_count=3')

    # Parse the optimizer config and derive which parameter value types
    # (value/gradient/momentum...) the gradient machine must allocate.
    optimizer_proto = paddle.config.parse_optimizer(
        optimizer_conf=optimizer_config)
    optimizer_conf = paddle.raw.OptimizationConfig.createFromProto(
        optimizer_proto)
    __tmp_optimizer__ = paddle.raw.ParameterOptimizer.create(optimizer_conf)
    assert isinstance(__tmp_optimizer__, paddle.raw.ParameterOptimizer)
    enable_types = __tmp_optimizer__.getParameterTypes()

    model_proto = paddle.config.parse_network(
        network_conf=bow_config(len(train_reader.word_dict)))

    for param in model_proto.parameters:
        if param.sparse_remote_update:
            # disable sparse remote update, when local
            param.sparse_remote_update = False

    gradient_machine = paddle.raw.GradientMachine.createFromConfigProto(
        model_proto, paddle.raw.CREATE_MODE_NORMAL, enable_types)
    assert isinstance(gradient_machine, paddle.raw.GradientMachine)
    gradient_machine.randParameters()

    updater = paddle.raw.ParameterUpdater.createLocalUpdater(optimizer_conf)
    assert isinstance(updater, paddle.raw.ParameterUpdater)

    # Reorder the declared input types to match the proto's input layer
    # order, then build the converter that turns python batches into
    # paddle Arguments.
    input_order = model_proto.input_layer_names
    input_types = {
        'sentence':
        paddle.data.sparse_binary_vector(len(train_reader.word_dict)),
        'label': paddle.data.integer_value(2)
    }

    tmp = []
    for each in input_order:
        tmp.append(input_types[each])

    input_types = tmp

    converter = paddle.data.DataProviderConverter(input_types=input_types)

    # Data files store (word_ids, label); if the network expects the
    # opposite order, swap each pair on the fly.
    # NOTE(review): swap_batch returns a generator, but below the code
    # calls len(data_batch) after swapping -- that would raise TypeError
    # if the switcher were ever active; presumably input orders match in
    # practice. Confirm.
    input_order_for_data = ['sentence', 'label']
    switcher = None
    if input_order_for_data != input_order:
        switcher = swap_batch

    updater.init(gradient_machine)

    gradient_machine.start()

    train_evaluator = gradient_machine.makeEvaluator()
    test_evaluator = gradient_machine.makeEvaluator()
    assert isinstance(train_evaluator, paddle.raw.Evaluator)
    assert isinstance(test_evaluator, paddle.raw.Evaluator)

    # Print training stats every this many batches.
    train_evaluate_period = 100

    out_args = paddle.raw.Arguments.createArguments(0)
    assert isinstance(out_args, paddle.raw.Arguments)
    for pass_id in xrange(10):
        updater.startPass()
        for batch_id, data_batch in enumerate(train_reader):
            if switcher is not None:
                data_batch = switcher(data_batch)

            updater.startBatch(len(data_batch))

            in_args = converter(data_batch)

            if batch_id % train_evaluate_period == 0:
                train_evaluator.start()

            gradient_machine.forwardBackward(in_args, out_args,
                                             paddle.raw.PASS_TRAIN)

            gradient_machine.eval(train_evaluator)

            # Average cost over this batch.
            cost = out_args.sumCosts() / len(data_batch)

            if batch_id % train_evaluate_period == 0:
                print 'Pass=%d Batch=%d Cost=%f' % (pass_id, batch_id,
                                                    cost), train_evaluator
                train_evaluator.finish()

            # NOTE(review): train_evaluator is eval'd a second time here
            # every batch, including right after finish() above -- looks
            # unintentional (duplicate accumulation); verify against the
            # Evaluator API before changing.
            gradient_machine.eval(train_evaluator)

            for each_param in gradient_machine.getParameters():
                updater.update(each_param)

            updater.finishBatch(cost)

        # End-of-pass summary (uses the last batch's id/cost).
        print 'Pass=%d Batch=%d Cost=%f' % (pass_id, batch_id,
                                            cost), train_evaluator
        # Flush any pending (e.g. regularization) updates before testing.
        updater.catchUpWith()

        test_evaluator.start()
        for data_batch in test_reader:
            if switcher is not None:
                data_batch = switcher(data_batch)

            in_args = converter(data_batch)
            gradient_machine.forward(in_args, out_args, paddle.raw.PASS_TEST)
            gradient_machine.eval(test_evaluator)

        print 'Test Pass=%d' % pass_id, test_evaluator

        # Snapshot every parameter as "<pass>_<name>" in the cwd.
        print 'Saving parameters.'
        for param in gradient_machine.getParameters():
            assert isinstance(param, paddle.raw.Parameter)
            save_name = "%d_%s" % (pass_id, param.getName())
            param.save(save_name)
        print 'Done.'

        test_evaluator.finish()

        updater.finishPass()
    gradient_machine.finish()
239+
240+
241+
# Script entry point: run the full train/test loop when executed directly.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)