Skip to content

Commit 7fad723

Browse files
committed
added caching; separated everything into classes; enriched api response
1 parent 4300acb commit 7fad723

File tree

1 file changed

+104
-32
lines changed

1 file changed

+104
-32
lines changed

server.py

Lines changed: 104 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#!/usr/bin/env python3
22
# github.com/deadbits/vector-embedding-api
3+
# server.py
34
import os
45
import sys
6+
import time
57
import argparse
8+
import hashlib
69
import logging
710
import configparser
811

@@ -29,7 +32,7 @@ def __init__(self, config_file):
2932
logging.info(f'Loading config file: {self.config_file}')
3033
self.config.read(config_file)
3134

32-
def get(self, section, key):
35+
def get_val(self, section, key):
3336
answer = None
3437

3538
try:
@@ -38,44 +41,115 @@ def get(self, section, key):
3841
logging.error(f'Config file missing section: {section} - {err}')
3942

4043
return answer
44+
45+
def get_bool(self, section, key, default=False):
46+
try:
47+
return self.config.getboolean(section, key)
48+
except Exception as err:
49+
logging.error(f'Failed to parse boolean - returning default "False": {section} - {err}')
50+
return default
51+
52+
53+
class EmbeddingCache:
54+
def __init__(self):
55+
logger.info('Created in-memory cache')
56+
self.cache = {}
57+
58+
def get_cache_key(self, text, model_type):
59+
return hashlib.sha256((text + model_type).encode()).hexdigest()
60+
61+
def get(self, text, model_type):
62+
return self.cache.get(self.get_cache_key(text, model_type))
4163

64+
def set(self, text, model_type, embedding):
65+
self.cache[self.get_cache_key(text, model_type)] = embedding
4266

43-
def get_openai_embeddings(text: str) -> list:
44-
try:
45-
response = openai.Embedding.create(input=text, model='text-embedding-ada-002')
46-
return response['data'][0]['embedding']
47-
except Exception as err:
48-
logger.error(f'Failed to get OpenAI embeddings: {err}')
49-
abort(500, 'Failed to get OpenAI embeddings')
5067

68+
class EmbeddingGenerator:
69+
def __init__(self, sbert_model=None, openai_key=None):
70+
self.sbert_model = sbert_model
71+
if self.sbert_model is not None:
72+
try:
73+
self.model = SentenceTransformer(self.sbert_model)
74+
logger.info(f'enabled model: {self.sbert_model}')
75+
except Exception as err:
76+
logger.error(f'Failed to load SentenceTransformer model "{self.sbert_model}": {err}')
77+
sys.exit(1)
5178

52-
def get_transformers_embeddings(text: str) -> list:
53-
try:
54-
return model.encode(text).tolist()
55-
except Exception as err:
56-
logger.error(f'Failed to get sentence-transformers embeddings: {err}')
57-
abort(500, 'Failed to get sentence-transformers embeddings')
79+
if openai_key is not None:
80+
openai.api_key = openai_key
81+
logger.info('enabled model: text-embedding-ada-002')
82+
83+
def get_openai_embeddings(self, text):
84+
start_time = time.time()
85+
86+
try:
87+
response = openai.Embedding.create(input=text, model='text-embedding-ada-002')
88+
elapsed_time = (time.time() - start_time) * 1000
89+
data = {
90+
"embedding": response['data'][0]['embedding'],
91+
"status": "success",
92+
"elapsed": elapsed_time,
93+
"model": "text-embedding-ada-002"
94+
}
95+
return data
96+
except Exception as err:
97+
logger.error(f'Failed to get OpenAI embeddings: {err}')
98+
return {"status": "error", "message": str(err), "model": "text-embedding-ada-002"}
99+
100+
def get_transformers_embeddings(self, text):
101+
start_time = time.time()
102+
103+
try:
104+
embedding = self.model.encode(text).tolist()
105+
elapsed_time = (time.time() - start_time) * 1000
106+
data = {
107+
"embedding": embedding,
108+
"status": "success",
109+
"elapsed": elapsed_time,
110+
"model": self.sbert_model
111+
}
112+
return data
113+
except Exception as err:
114+
logger.error(f'Failed to get sentence-transformers embeddings: {err}')
115+
return {"status": "error", "message": str(err), "model": self.sbert_model}
116+
117+
def generate(self, text, model_type):
118+
if model_type == 'openai':
119+
return self.get_openai_embeddings(text)
120+
else:
121+
return self.get_transformers_embeddings(text)
58122

59123

60124
@app.route('/submit', methods=['POST'])
61125
def submit_text():
62126
data = request.json
63-
127+
64128
text_data = data.get('text')
65129
model_type = data.get('model', 'local').lower()
66130

67131
if text_data is None:
68132
abort(400, 'Missing text data to embed')
69-
133+
70134
if model_type not in ['local', 'openai']:
71135
abort(400, 'model field must be one of: local, openai')
72136

73-
if model_type == 'openai':
74-
embedding_data = get_openai_embeddings(text_data)
137+
if embedding_cache:
138+
result = embedding_cache.get(text_data, model_type)
139+
if result:
140+
logger.info('found embedding in cache!')
141+
result = {'embedding': result, 'cache': True, "status": 'success'}
75142
else:
76-
embedding_data = get_transformers_embeddings(text_data)
143+
result = None
144+
145+
if result is None:
146+
result = embedding_generator.generate(text_data, model_type)
147+
148+
if embedding_cache and result['status'] == 'success':
149+
embedding_cache.set(text_data, model_type, result['embedding'])
150+
logger.info('added to cache')
77151

78-
return jsonify({'embedding': embedding_data, 'status': 'success'})
152+
return jsonify(result)
79153

80154

81155
if __name__ == '__main__':
@@ -91,23 +165,21 @@ def submit_text():
91165
args = parser.parse_args()
92166

93167
conf = Config(args.config)
94-
api_key = conf.get('main', 'openai_api_key')
95-
sent_model = conf.get('main', 'sent_transformers_model')
168+
openai_key = conf.get_val('main', 'openai_api_key')
169+
sbert_model = conf.get_val('main', 'sent_transformers_model')
170+
use_cache = conf.get_bool('main', 'use_cache', default=False)
96171

97-
if api_key is None:
172+
if openai_key is None:
98173
logger.warn('No OpenAI API key set in configuration file: server.conf')
99-
else:
100-
logger.info('Set OpenAI API key via openai.api_key')
101-
openai.api_key = api_key
102174

103-
if sent_model is None:
175+
if sbert_model is None:
104176
logger.warn('No transformer model set in configuration file: server.conf')
105177

106-
try:
107-
model = SentenceTransformer(sent_model)
108-
except Exception as err:
109-
logger.error(f'Failed to load SentenceTransformer model "{sent_model}": {err}')
178+
if openai_key is None and sbert_model is None:
179+
logger.error('No sbert model set *and* no openAI key set; exiting')
110180
sys.exit(1)
111181

112-
app.run(debug=True)
182+
embedding_cache = EmbeddingCache() if use_cache else None
183+
embedding_generator = EmbeddingGenerator(sbert_model, openai_key)
113184

185+
app.run(debug=True)

0 commit comments

Comments
 (0)