11#!/usr/bin/env python3
22# github.com/deadbits/vector-embedding-api
3+ # server.py
34import os
45import sys
6+ import time
57import argparse
8+ import hashlib
69import logging
710import configparser
811
@@ -29,7 +32,7 @@ def __init__(self, config_file):
2932 logging .info (f'Loading config file: { self .config_file } ' )
3033 self .config .read (config_file )
3134
32- def get (self , section , key ):
35+ def get_val (self , section , key ):
3336 answer = None
3437
3538 try :
@@ -38,44 +41,115 @@ def get(self, section, key):
3841 logging .error (f'Config file missing section: { section } - { err } ' )
3942
4043 return answer
44+
45+ def get_bool (self , section , key , default = False ):
46+ try :
47+ return self .config .getboolean (section , key )
48+ except Exception as err :
49+ logging .error (f'Failed to parse boolean - returning default "False": { section } - { err } ' )
50+ return default
51+
52+
53+ class EmbeddingCache :
54+ def __init__ (self ):
55+ logger .info ('Created in-memory cache' )
56+ self .cache = {}
57+
58+ def get_cache_key (self , text , model_type ):
59+ return hashlib .sha256 ((text + model_type ).encode ()).hexdigest ()
60+
61+ def get (self , text , model_type ):
62+ return self .cache .get (self .get_cache_key (text , model_type ))
4163
64+ def set (self , text , model_type , embedding ):
65+ self .cache [self .get_cache_key (text , model_type )] = embedding
4266
43- def get_openai_embeddings (text : str ) -> list :
44- try :
45- response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
46- return response ['data' ][0 ]['embedding' ]
47- except Exception as err :
48- logger .error (f'Failed to get OpenAI embeddings: { err } ' )
49- abort (500 , 'Failed to get OpenAI embeddings' )
5067
68+ class EmbeddingGenerator :
69+ def __init__ (self , sbert_model = None , openai_key = None ):
70+ self .sbert_model = sbert_model
71+ if self .sbert_model is not None :
72+ try :
73+ self .model = SentenceTransformer (self .sbert_model )
74+ logger .info (f'enabled model: { self .sbert_model } ' )
75+ except Exception as err :
76+ logger .error (f'Failed to load SentenceTransformer model "{ self .sbert_model } ": { err } ' )
77+ sys .exit (1 )
5178
52- def get_transformers_embeddings (text : str ) -> list :
53- try :
54- return model .encode (text ).tolist ()
55- except Exception as err :
56- logger .error (f'Failed to get sentence-transformers embeddings: { err } ' )
57- abort (500 , 'Failed to get sentence-transformers embeddings' )
79+ if openai_key is not None :
80+ openai .api_key = openai_key
81+ logger .info ('enabled model: text-embedding-ada-002' )
82+
83+ def get_openai_embeddings (self , text ):
84+ start_time = time .time ()
85+
86+ try :
87+ response = openai .Embedding .create (input = text , model = 'text-embedding-ada-002' )
88+ elapsed_time = (time .time () - start_time ) * 1000
89+ data = {
90+ "embedding" : response ['data' ][0 ]['embedding' ],
91+ "status" : "success" ,
92+ "elapsed" : elapsed_time ,
93+ "model" : "text-embedding-ada-002"
94+ }
95+ return data
96+ except Exception as err :
97+ logger .error (f'Failed to get OpenAI embeddings: { err } ' )
98+ return {"status" : "error" , "message" : str (err ), "model" : "text-embedding-ada-002" }
99+
100+ def get_transformers_embeddings (self , text ):
101+ start_time = time .time ()
102+
103+ try :
104+ embedding = self .model .encode (text ).tolist ()
105+ elapsed_time = (time .time () - start_time ) * 1000
106+ data = {
107+ "embedding" : embedding ,
108+ "status" : "success" ,
109+ "elapsed" : elapsed_time ,
110+ "model" : self .sbert_model
111+ }
112+ return data
113+ except Exception as err :
114+ logger .error (f'Failed to get sentence-transformers embeddings: { err } ' )
115+ return {"status" : "error" , "message" : str (err ), "model" : self .sbert_model }
116+
117+ def generate (self , text , model_type ):
118+ if model_type == 'openai' :
119+ return self .get_openai_embeddings (text )
120+ else :
121+ return self .get_transformers_embeddings (text )
58122
59123
60124@app .route ('/submit' , methods = ['POST' ])
61125def submit_text ():
62126 data = request .json
63-
127+
64128 text_data = data .get ('text' )
65129 model_type = data .get ('model' , 'local' ).lower ()
66130
67131 if text_data is None :
68132 abort (400 , 'Missing text data to embed' )
69-
133+
70134 if model_type not in ['local' , 'openai' ]:
71135 abort (400 , 'model field must be one of: local, openai' )
72136
73- if model_type == 'openai' :
74- embedding_data = get_openai_embeddings (text_data )
137+ if embedding_cache :
138+ result = embedding_cache .get (text_data , model_type )
139+ if result :
140+ logger .info ('found embedding in cache!' )
141+ result = {'embedding' : result , 'cache' : True , "status" : 'success' }
75142 else :
76- embedding_data = get_transformers_embeddings (text_data )
143+ result = None
144+
145+ if result is None :
146+ result = embedding_generator .generate (text_data , model_type )
147+
148+ if embedding_cache and result ['status' ] == 'success' :
149+ embedding_cache .set (text_data , model_type , result ['embedding' ])
150+ logger .info ('added to cache' )
77151
78- return jsonify ({ 'embedding' : embedding_data , 'status' : 'success' } )
152+ return jsonify (result )
79153
80154
81155if __name__ == '__main__' :
@@ -91,23 +165,21 @@ def submit_text():
91165 args = parser .parse_args ()
92166
93167 conf = Config (args .config )
94- api_key = conf .get ('main' , 'openai_api_key' )
95- sent_model = conf .get ('main' , 'sent_transformers_model' )
168+ openai_key = conf .get_val ('main' , 'openai_api_key' )
169+ sbert_model = conf .get_val ('main' , 'sent_transformers_model' )
170+ use_cache = conf .get_bool ('main' , 'use_cache' , default = False )
96171
97- if api_key is None :
172+ if openai_key is None :
98173 logger .warn ('No OpenAI API key set in configuration file: server.conf' )
99- else :
100- logger .info ('Set OpenAI API key via openai.api_key' )
101- openai .api_key = api_key
102174
103- if sent_model is None :
175+ if sbert_model is None :
104176 logger .warn ('No transformer model set in configuration file: server.conf' )
105177
106- try :
107- model = SentenceTransformer (sent_model )
108- except Exception as err :
109- logger .error (f'Failed to load SentenceTransformer model "{ sent_model } ": { err } ' )
178+ if openai_key is None and sbert_model is None :
179+ logger .error ('No sbert model set *and* no openAI key set; exiting' )
110180 sys .exit (1 )
111181
112- app .run (debug = True )
182+ embedding_cache = EmbeddingCache () if use_cache else None
183+ embedding_generator = EmbeddingGenerator (sbert_model , openai_key )
113184
185+ app .run (debug = True )
0 commit comments