
Commit 9cef63c

Load the whole RAG system during build and cache embeddings
1 parent 6fd9ab6 commit 9cef63c

File tree

3 files changed, +27 -24 lines changed

app/Dockerfile

Lines changed: 3 additions & 5 deletions
@@ -36,14 +36,12 @@ RUN pip install --no-cache-dir -r requirements.txt
 # Set the environment variable for the sentence transformers model
 ENV SENTENCE_TRANSFORMERS_HOME="/root/.cache/sentence_transformers"
 
-COPY ./embeddings.py /app/embeddings.py
-
-# Preload the sentence transformer model to cache
-RUN python embeddings.py
-
 # Copy the application source code into the container
 COPY . /app
 
+# Preload the sentence transformer model to cache
+RUN python rag_system.py
+
 # Expose port 5050 for the Flask application
 EXPOSE 5050
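
The build now preloads everything through rag_system.py: instantiating RAGSystem downloads the all-MiniLM-L6-v2 weights into SENTENCE_TRANSFORMERS_HOME and writes the document embeddings to disk, so both end up baked into the image. A minimal sanity-check sketch that could be run inside the built image (hypothetical, not part of this commit; paths assumed from the ENV above and the np.save() call in rag_system.py below):

import os
import numpy as np

# Paths assumed from the Dockerfile ENV and rag_system.py's cache location.
model_cache = os.getenv("SENTENCE_TRANSFORMERS_HOME", "/root/.cache/sentence_transformers")
embeddings_path = "./data/doc_embeddings.npy"

print("model cache present:", os.path.isdir(model_cache))
if os.path.exists(embeddings_path):
    print("cached embeddings shape:", np.load(embeddings_path).shape)
else:
    print("no cached embeddings yet; they will be computed on first run")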

app/embeddings.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

app/rag_system.py

Lines changed: 24 additions & 4 deletions
@@ -2,11 +2,11 @@
 import json
 import os
 import sys
+import logging
 from datetime import date
 from sentence_transformers import SentenceTransformer
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
-from embeddings import load_model
 import traceback
 
 openai.api_base = os.getenv("OPENAI_BASE_URL")
@@ -15,9 +15,21 @@
 class RAGSystem:
     def __init__(self, knowledge_base_path='./data/knowledge_base.json'):
         self.knowledge_base_path = knowledge_base_path
+
         self.knowledge_base = self.load_knowledge_base()
-        self.model = load_model()
-        self.doc_embeddings = self.embed_knowledge_base()
+        self.model = SentenceTransformer("all-MiniLM-L6-v2")
+
+        # load existing embeddings if available
+        logging.info("Embedding knowledge base...")
+        if os.path.exists('./data/doc_embeddings.npy'):
+            self.doc_embeddings = np.load('./data/doc_embeddings.npy')
+            logging.info("Loaded existing document embeddings from disk.")
+        else:
+            logging.info("No existing document embeddings found, creating new embeddings.")
+            self.doc_embeddings = self.embed_knowledge_base()
+            # cache doc_embeddings to disk
+            np.save('./data/doc_embeddings.npy', self.doc_embeddings.cpu().numpy())
+            logging.info("Knowledge base embeddings created")
         self.conversation_history = []
 
     def load_knowledge_base(self):
@@ -38,7 +50,7 @@ def get_query_embedding(self, query):
         return query_embedding
 
     def get_doc_embeddings(self):
-        return self.doc_embeddings.cpu()
+        return self.doc_embeddings
 
     def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
         text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
@@ -188,3 +200,11 @@ def get_context(self, retrieved_docs):
         for doc in retrieved_docs:
             retrieved_text.append(f"{doc['about']}. {doc['text']}")
         return "\n\n".join(retrieved_text)
+
+if __name__ == "__main__":
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s %(levelname)s %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S"
+    )
+    RAGSystem()
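
The new __main__ block is what `RUN python rag_system.py` in the Dockerfile executes: it configures logging and constructs RAGSystem once, which embeds and caches the knowledge base at build time. A minimal runtime sketch, assuming the application constructs the class the same way (the surrounding app code is not part of this diff):

from rag_system import RAGSystem

# At container start the constructor finds ./data/doc_embeddings.npy written
# during the build and loads it instead of re-embedding the knowledge base.
rag = RAGSystem()

doc_embeddings = rag.get_doc_embeddings()
print("embedded documents:", len(doc_embeddings))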
