22# github.com/deadbits/vector-embedding-api
33import os
44import sys
5+ import argparse
56import logging
67import configparser
78
1314
1415app = Flask (__name__ )
1516
17+ logging .basicConfig (level = logging .INFO )
18+ logger = logging .getLogger (__name__ )
19+
1620
1721class Config :
1822 def __init__ (self , config_file ):
@@ -30,13 +34,13 @@ def get(self, section, key):
3034
3135 try :
3236 answer = self .config .get (section , key )
33- except :
34- logging .error (f'Config file missing section: { section } ' )
37+ except Exception as err :
38+ logging .error (f'Config file missing section: { section } - { err } ' )
3539
3640 return answer
3741
3842
39- def get_openai_embeddings (text : str ):
43+ def get_openai_embeddings (text : str ) -> list :
4044 try :
4145 response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
4246 return response ['data' ][0 ]['embedding' ]
@@ -45,7 +49,7 @@ def get_openai_embeddings(text: str):
4549 abort (500 , 'Failed to get OpenAI embeddings' )
4650
4751
48- def get_transformers_embeddings (text : str ):
52+ def get_transformers_embeddings (text : str ) -> list :
4953 try :
5054 return model .encode (text ).tolist ()
5155 except Exception as err :
@@ -56,26 +60,48 @@ def get_transformers_embeddings(text: str):
@app.route('/submit', methods=['POST'])
def submit_text():
    """Embed the posted text and return the vector as JSON.

    Expects a JSON object body with:
      - ``text``: the text to embed (required).
      - ``model``: ``'local'`` (default) or ``'openai'``, case-insensitive.

    Returns:
        A JSON response ``{'embedding': <list>, 'status': 'success'}``.

    Aborts with HTTP 400 when the body is not a JSON object, when ``text``
    is missing, or when ``model`` is not one of the accepted values.
    """
    data = request.json

    # A syntactically valid JSON body may still be null, a list, a string,
    # etc. None of those has .get(), so reject them as a client error
    # instead of letting an AttributeError surface as a 500.
    if not isinstance(data, dict):
        abort(400, 'Missing text data to embed')

    text_data = data.get('text')
    if text_data is None:
        abort(400, 'Missing text data to embed')

    # "model" can arrive as any JSON type; only strings can be lower-cased.
    model_type = data.get('model', 'local')
    if not isinstance(model_type, str):
        abort(400, 'model field must be one of: local, openai')

    model_type = model_type.lower()
    if model_type not in ('local', 'openai'):
        abort(400, 'model field must be one of: local, openai')

    if model_type == 'openai':
        embedding_data = get_openai_embeddings(text_data)
    else:
        embedding_data = get_transformers_embeddings(text_data)

    return jsonify({'embedding': embedding_data, 'status': 'success'})
7079
7180
7281if __name__ == '__main__' :
73- conf = Config ('server.conf' )
74- openai .api_key = conf .get ('main' , 'openai_api_key' )
82+ parser = argparse .ArgumentParser ()
83+
84+ parser .add_argument (
85+ '-c' , '--config' ,
86+ help = 'config file' ,
87+ type = str ,
88+ required = True
89+ )
90+
91+ args = parser .parse_args ()
92+
93+ conf = Config (args .config )
94+ api_key = conf .get ('main' , 'openai_api_key' )
7595 sent_model = conf .get ('main' , 'sent_transformers_model' )
7696
77- logging .basicConfig (level = logging .INFO )
78- logger = logging .getLogger (__name__ )
97+ if api_key is None :
98+ logger .warn ('No OpenAI API key set in configuration file: server.conf' )
99+ else :
100+ logger .info ('Set OpenAI API key via openai.api_key' )
101+ openai .api_key = api_key
102+
103+ if sent_model is None :
104+ logger .warn ('No transformer model set in configuration file: server.conf' )
79105
80106 try :
81107 model = SentenceTransformer (sent_model )
0 commit comments