Skip to content

Commit 77130c3

Browse files
committed
add python api_predict for quick start
1 parent c5d5ad3 commit 77130c3

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

demo/quick_start/api_predict.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
from py_paddle import swig_paddle, DataProviderConverter
1919
from paddle.trainer.PyDataProvider2 import sparse_binary_vector
2020
from paddle.trainer.config_parser import parse_config
21-
22-
2321
"""
2422
Usage: run following command to show help message.
2523
python api_predict.py -h
2624
"""
2725

26+
2827
class QuickStartPrediction():
2928
def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
3029
"""
@@ -72,9 +71,7 @@ def get_index(self, data):
7271
transform word into integer index according to the dictionary.
7372
"""
7473
words = data.strip().split()
75-
word_slot = [
76-
self.word_dict[w] for w in words if w in self.word_dict
77-
]
74+
word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
7875
return word_slot
7976

8077
def batch_predict(self, data_batch):
@@ -84,6 +81,7 @@ def batch_predict(self, data_batch):
8481
print("predicting labels is:")
8582
print prob
8683

84+
8785
def option_parser():
8886
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
8987
parser = OptionParser(usage="usage: %s [options]" % usage)
@@ -144,5 +142,6 @@ def main():
144142
print labels
145143
predict.batch_predict(batch)
146144

145+
147146
if __name__ == '__main__':
148147
main()

0 commit comments

Comments
 (0)