Skip to content

Commit dd894c2

Browse files
authored
Merge pull request #876 from jacquesqiao/develop
add python api_predict for quick start
2 parents 638cf8d + 0c7d553 commit dd894c2

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed

demo/quick_start/api_predict.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os, sys
16+
import numpy as np
17+
from optparse import OptionParser
18+
from py_paddle import swig_paddle, DataProviderConverter
19+
from paddle.trainer.PyDataProvider2 import sparse_binary_vector
20+
from paddle.trainer.config_parser import parse_config
21+
"""
22+
Usage: run following command to show help message.
23+
python api_predict.py -h
24+
"""
25+
26+
27+
class QuickStartPrediction():
28+
def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
29+
"""
30+
train_conf: trainer configure.
31+
dict_file: word dictionary file name.
32+
model_dir: directory of model.
33+
"""
34+
self.train_conf = train_conf
35+
self.dict_file = dict_file
36+
self.word_dict = {}
37+
self.dict_dim = self.load_dict()
38+
self.model_dir = model_dir
39+
if model_dir is None:
40+
self.model_dir = os.path.dirname(train_conf)
41+
42+
self.label = None
43+
if label_file is not None:
44+
self.load_label(label_file)
45+
46+
conf = parse_config(train_conf, "is_predict=1")
47+
self.network = swig_paddle.GradientMachine.createFromConfigProto(
48+
conf.model_config)
49+
self.network.loadParameters(self.model_dir)
50+
input_types = [sparse_binary_vector(self.dict_dim)]
51+
self.converter = DataProviderConverter(input_types)
52+
53+
def load_dict(self):
54+
"""
55+
Load dictionary from self.dict_file.
56+
"""
57+
for line_count, line in enumerate(open(self.dict_file, 'r')):
58+
self.word_dict[line.strip().split('\t')[0]] = line_count
59+
return len(self.word_dict)
60+
61+
def load_label(self, label_file):
62+
"""
63+
Load label.
64+
"""
65+
self.label = {}
66+
for v in open(label_file, 'r'):
67+
self.label[int(v.split('\t')[1])] = v.split('\t')[0]
68+
69+
def get_index(self, data):
70+
"""
71+
transform word into integer index according to the dictionary.
72+
"""
73+
words = data.strip().split()
74+
word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
75+
return word_slot
76+
77+
def batch_predict(self, data_batch):
78+
input = self.converter(data_batch)
79+
output = self.network.forwardTest(input)
80+
prob = output[0]["id"].tolist()
81+
print("predicting labels is:")
82+
print prob
83+
84+
85+
def option_parser():
86+
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
87+
parser = OptionParser(usage="usage: %s [options]" % usage)
88+
parser.add_option(
89+
"-n",
90+
"--tconf",
91+
action="store",
92+
dest="train_conf",
93+
help="network config")
94+
parser.add_option(
95+
"-d",
96+
"--dict",
97+
action="store",
98+
dest="dict_file",
99+
help="dictionary file")
100+
parser.add_option(
101+
"-b",
102+
"--label",
103+
action="store",
104+
dest="label",
105+
default=None,
106+
help="dictionary file")
107+
parser.add_option(
108+
"-c",
109+
"--batch_size",
110+
type="int",
111+
action="store",
112+
dest="batch_size",
113+
default=1,
114+
help="the batch size for prediction")
115+
parser.add_option(
116+
"-w",
117+
"--model",
118+
action="store",
119+
dest="model_path",
120+
default=None,
121+
help="model path")
122+
return parser.parse_args()
123+
124+
125+
def main():
126+
options, args = option_parser()
127+
train_conf = options.train_conf
128+
batch_size = options.batch_size
129+
dict_file = options.dict_file
130+
model_path = options.model_path
131+
label = options.label
132+
swig_paddle.initPaddle("--use_gpu=0")
133+
predict = QuickStartPrediction(train_conf, dict_file, model_path, label)
134+
135+
batch = []
136+
labels = []
137+
for line in sys.stdin:
138+
[label, text] = line.split("\t")
139+
labels.append(int(label))
140+
batch.append([predict.get_index(text)])
141+
print("labels is:")
142+
print labels
143+
predict.batch_predict(batch)
144+
145+
146+
if __name__ == '__main__':
147+
main()

demo/quick_start/api_predict.sh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/bin/bash
2+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
set -e
16+
17+
#Note the default model is pass-00002, you shold make sure the model path
18+
#exists or change the mode path.
19+
#only test on trainer_config.lr.py
20+
model=output/pass-00001/
21+
config=trainer_config.lr.py
22+
label=data/labels.list
23+
dict=data/dict.txt
24+
batch_size=20
25+
head -n$batch_size data/test.txt | python api_predict.py \
26+
--tconf=$config\
27+
--model=$model \
28+
--label=$label \
29+
--dict=$dict \
30+
--batch_size=$batch_size

0 commit comments

Comments
 (0)