Skip to content

Commit c5d5ad3

Browse files
committed
add python api_predict for quick start
1 parent b0c6331 commit c5d5ad3

File tree

2 files changed

+178
-0
lines changed

2 files changed

+178
-0
lines changed

demo/quick_start/api_predict.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os, sys
16+
import numpy as np
17+
from optparse import OptionParser
18+
from py_paddle import swig_paddle, DataProviderConverter
19+
from paddle.trainer.PyDataProvider2 import sparse_binary_vector
20+
from paddle.trainer.config_parser import parse_config
21+
22+
23+
"""
24+
Usage: run following command to show help message.
25+
python api_predict.py -h
26+
"""
27+
28+
class QuickStartPrediction(object):
    """
    Predictor for the quick_start demo built on the PaddlePaddle Python API.

    On construction it loads the word dictionary (and, optionally, a label
    file), parses the trainer config in prediction mode, creates the network
    and loads its parameters from the model directory.
    """

    def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
        """
        train_conf: trainer configure.
        dict_file: word dictionary file name.
        model_dir: directory of model. When None, the directory that
                   contains train_conf is used.
        label_file: optional label file name. When given, it is loaded
                    into self.label; otherwise self.label stays None.
        """
        self.train_conf = train_conf
        self.dict_file = dict_file
        # word_dict must exist before load_dict() fills it.
        self.word_dict = {}
        self.dict_dim = self.load_dict()
        self.model_dir = model_dir
        if model_dir is None:
            self.model_dir = os.path.dirname(train_conf)

        self.label = None
        if label_file is not None:
            self.load_label(label_file)

        conf = parse_config(train_conf, "is_predict=1")
        self.network = swig_paddle.GradientMachine.createFromConfigProto(
            conf.model_config)
        self.network.loadParameters(self.model_dir)
        input_types = [sparse_binary_vector(self.dict_dim)]
        self.converter = DataProviderConverter(input_types)

    def load_dict(self):
        """
        Load dictionary from self.dict_file.

        The first tab-separated field of each line is the word; its index is
        the zero-based line number. Returns the number of distinct words.
        """
        # The context manager closes the file even if a line fails to parse.
        with open(self.dict_file, 'r') as dict_fp:
            for line_count, line in enumerate(dict_fp):
                # Lines are never skipped: the index has to stay equal to
                # the line number in the dictionary file.
                self.word_dict[line.strip().split('\t')[0]] = line_count
        return len(self.word_dict)

    def load_label(self, label_file):
        """
        Load label.

        Each non-blank line holds "<label name>\\t<integer id>"; the result
        is stored in self.label as a dict mapping id -> label name.
        """
        self.label = {}
        with open(label_file, 'r') as label_fp:
            for line in label_fp:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                fields = line.split('\t')
                # int() ignores the trailing newline on the id field.
                self.label[int(fields[1])] = fields[0]

    def get_index(self, data):
        """
        transform word into integer index according to the dictionary.

        data is split on whitespace; words missing from the dictionary are
        dropped. Returns the list of indices in the original word order.
        """
        words = data.strip().split()
        return [self.word_dict[w] for w in words if w in self.word_dict]

    def batch_predict(self, data_batch):
        """
        Run a forward pass on data_batch and print the result.

        data_batch is passed through self.converter and then to the
        network's forwardTest; the "id" field of the first output is printed
        as a list.
        """
        # Named batch_input so the builtin input() is not shadowed.
        batch_input = self.converter(data_batch)
        output = self.network.forwardTest(batch_input)
        # NOTE(review): despite the name, this holds the "id" output,
        # presumably the predicted label ids -- confirm against the config.
        prob = output[0]["id"].tolist()
        print("predicting labels is:")
        print(prob)
86+
87+
def option_parser():
    """
    Build the command line parser and parse sys.argv.

    Returns the (options, args) pair produced by OptionParser.parse_args().
    The options carry: train_conf, dict_file, label, batch_size (int,
    default 1) and model_path.
    """
    # The samples are read from standard input, not from an option.
    usage = "python api_predict.py -n config -w model_dir -d dictionary < input_file"
    parser = OptionParser(usage="usage: %s [options]" % usage)
    parser.add_option(
        "-n",
        "--tconf",
        action="store",
        dest="train_conf",
        help="network config")
    parser.add_option(
        "-d",
        "--dict",
        action="store",
        dest="dict_file",
        help="dictionary file")
    parser.add_option(
        "-b",
        "--label",
        action="store",
        dest="label",
        default=None,
        help="label file")
    parser.add_option(
        "-c",
        "--batch_size",
        type="int",
        action="store",
        dest="batch_size",
        default=1,
        help="the batch size for prediction")
    parser.add_option(
        "-w",
        "--model",
        action="store",
        dest="model_path",
        default=None,
        help="model path")
    return parser.parse_args()
125+
126+
127+
def main():
    """
    Read "<label>\\t<text>" samples from stdin and predict them in one batch.

    Parses the command line, initialises Paddle on CPU, builds a
    QuickStartPrediction, then prints the integer labels found in the input
    followed by the network's predictions.
    """
    options, _ = option_parser()
    swig_paddle.initPaddle("--use_gpu=0")
    predict = QuickStartPrediction(options.train_conf, options.dict_file,
                                   options.model_path, options.label)

    # NOTE: options.batch_size is parsed but not applied here; every line
    # read from stdin goes into a single batch.
    batch = []
    labels = []
    for line in sys.stdin:
        # Skip blank lines instead of failing on the unpacking below.
        if not line.strip():
            continue
        # Split on the first tab only, so a tab inside the text is harmless.
        sample_label, text = line.split("\t", 1)
        labels.append(int(sample_label))
        batch.append([predict.get_index(text)])
    print("labels is:")
    print(labels)
    predict.batch_predict(batch)


if __name__ == '__main__':
    main()

demo/quick_start/api_predict.sh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/bin/bash
2+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
set -e

# Note: the default model is output/pass-00001/. Make sure the model path
# exists, or change the model path below.
# Only tested with trainer_config.lr.py.
model=output/pass-00001/
config=trainer_config.lr.py
label=data/labels.list
dict=data/dict.txt
batch_size=20

# Feed the first $batch_size test samples to the predictor on stdin.
# Every continuation backslash is preceded by a space so that consecutive
# arguments can never be fused together, and every expansion is quoted.
head -n "$batch_size" data/test.txt | python api_predict.py \
     --tconf="$config" \
     --model="$model" \
     --label="$label" \
     --dict="$dict" \
     --batch_size="$batch_size"

0 commit comments

Comments
 (0)