Skip to content

Commit 17e3b52

Browse files
sentenceTransformer embedding model download locally to use (#1361)
* sentenceTransformer embedding model download locally to use * Added command in docker to download the embedding model once docker up * Fix the docker issue * download nltk punkt from docker * env fetching changes * Added library in docker * Remove duplicate import --------- Co-authored-by: kaustubh-darekar <[email protected]>
1 parent 2372a68 commit 17e3b52

File tree

8 files changed

+117
-54
lines changed

8 files changed

+117
-54
lines changed

backend/Dockerfile

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,33 @@ EXPOSE 8000
66
RUN apt-get update && \
77
apt-get install -y --no-install-recommends \
88
libmagic1 \
9-
libgl1-mesa-glx \
9+
libgl1 \
10+
libglx-mesa0 \
1011
libreoffice \
1112
cmake \
1213
poppler-utils \
1314
tesseract-ocr && \
1415
apt-get clean && \
1516
rm -rf /var/lib/apt/lists/*
17+
1618
# Set LD_LIBRARY_PATH
1719
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
1820
# Copy requirements file and install Python dependencies
1921
COPY requirements.txt constraints.txt /code/
2022
# --no-cache-dir --upgrade
2123
RUN pip install --upgrade pip
2224
RUN pip install -r requirements.txt -c constraints.txt
25+
26+
RUN python -c "from transformers import AutoTokenizer, AutoModel; \
27+
name='sentence-transformers/all-MiniLM-L6-v2'; \
28+
tok=AutoTokenizer.from_pretrained(name); \
29+
mod=AutoModel.from_pretrained(name); \
30+
tok.save_pretrained('./local_model'); \
31+
mod.save_pretrained('./local_model')"
32+
33+
RUN python -m nltk.downloader -d /usr/local/nltk_data punkt
34+
RUN python -m nltk.downloader -d /usr/local/nltk_data averaged_perceptron_tagger
35+
2336
# Copy application code
2437
COPY . /code
2538
# Set command

backend/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ wrapt==1.17.2
5353
yarl==1.20.1
5454
youtube-transcript-api==1.1.0
5555
zipp==3.23.0
56-
sentence-transformers==4.1.0
56+
sentence-transformers==5.0.0
5757
google-cloud-logging==3.12.1
5858
pypandoc==1.15
5959
graphdatascience==1.15.1
6060
Secweb==1.18.1
61-
ragas==0.2.15
61+
ragas==0.3.1
6262
rouge_score==0.1.2
6363
langchain-neo4j==0.4.0
6464
pypandoc-binary==1.15

backend/src/QA_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
load_dotenv()
3939

4040
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL')
41-
EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)
4241

4342
class SessionChatHistory:
4443
history_dict = {}
@@ -304,6 +303,7 @@ def create_document_retriever_chain(llm, retriever):
304303
output_parser = StrOutputParser()
305304

306305
splitter = TokenTextSplitter(chunk_size=CHAT_DOC_SPLIT_SIZE, chunk_overlap=0)
306+
EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)
307307
embeddings_filter = EmbeddingsFilter(
308308
embeddings=EMBEDDING_FUNCTION,
309309
similarity_threshold=CHAT_EMBEDDING_FILTER_SCORE_THRESHOLD
@@ -344,7 +344,7 @@ def initialize_neo4j_vector(graph, chat_mode_settings):
344344

345345
if not retrieval_query or not index_name:
346346
raise ValueError("Required settings 'retrieval_query' or 'index_name' are missing.")
347-
347+
EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)
348348
if keyword_index:
349349
neo_db = Neo4jVector.from_existing_graph(
350350
embedding=EMBEDDING_FUNCTION,

backend/src/document_sources/gcs_bucket.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -46,46 +46,58 @@ def gcs_loader_func(file_path):
4646
return loader
4747

4848
def get_documents_from_gcs(gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, access_token=None):
49-
nltk.download('punkt')
50-
nltk.download('averaged_perceptron_tagger')
51-
if gcs_bucket_folder is not None and gcs_bucket_folder.strip()!="":
52-
if gcs_bucket_folder.endswith('/'):
53-
blob_name = gcs_bucket_folder+gcs_blob_filename
49+
50+
nltk.data.path.append("/usr/local/nltk_data")
51+
nltk.data.path.append(os.path.expanduser("~/.nltk_data"))
52+
try:
53+
nltk.data.find("tokenizers/punkt")
54+
except LookupError:
55+
for resource in ["punkt", "averaged_perceptron_tagger"]:
56+
try:
57+
nltk.data.find(f"tokenizers/{resource}" if resource == "punkt" else f"taggers/{resource}")
58+
except LookupError:
59+
logging.info(f"Downloading NLTK resource: {resource}")
60+
nltk.download(resource, download_dir=os.path.expanduser("~/.nltk_data"))
61+
62+
logging.info("NLTK resources downloaded successfully.")
63+
if gcs_bucket_folder is not None and gcs_bucket_folder.strip()!="":
64+
if gcs_bucket_folder.endswith('/'):
65+
blob_name = gcs_bucket_folder+gcs_blob_filename
66+
else:
67+
blob_name = gcs_bucket_folder+'/'+gcs_blob_filename
5468
else:
55-
blob_name = gcs_bucket_folder+'/'+gcs_blob_filename
56-
else:
57-
blob_name = gcs_blob_filename
58-
59-
logging.info(f"GCS project_id : {gcs_project_id}")
60-
61-
if access_token is None:
62-
storage_client = storage.Client(project=gcs_project_id)
63-
bucket = storage_client.bucket(gcs_bucket_name)
64-
blob = bucket.blob(blob_name)
69+
blob_name = gcs_blob_filename
6570

66-
if blob.exists():
67-
loader = GCSFileLoader(project_name=gcs_project_id, bucket=gcs_bucket_name, blob=blob_name, loader_func=gcs_loader_func)
68-
pages = loader.load()
69-
else :
70-
raise LLMGraphBuilderException('File does not exist, Please re-upload the file and try again.')
71-
else:
72-
creds= Credentials(access_token)
73-
storage_client = storage.Client(project=gcs_project_id, credentials=creds)
71+
logging.info(f"GCS project_id : {gcs_project_id}")
7472

75-
bucket = storage_client.bucket(gcs_bucket_name)
76-
blob = bucket.blob(blob_name)
77-
if blob.exists():
78-
content = blob.download_as_bytes()
79-
pdf_file = io.BytesIO(content)
80-
pdf_reader = PdfReader(pdf_file)
81-
# Extract text from all pages
82-
text = ""
83-
for page in pdf_reader.pages:
84-
text += page.extract_text()
85-
pages = [Document(page_content = text)]
73+
if access_token is None:
74+
storage_client = storage.Client(project=gcs_project_id)
75+
bucket = storage_client.bucket(gcs_bucket_name)
76+
blob = bucket.blob(blob_name)
77+
78+
if blob.exists():
79+
loader = GCSFileLoader(project_name=gcs_project_id, bucket=gcs_bucket_name, blob=blob_name, loader_func=gcs_loader_func)
80+
pages = loader.load()
81+
else :
82+
raise LLMGraphBuilderException('File does not exist, Please re-upload the file and try again.')
8683
else:
87-
raise LLMGraphBuilderException(f'File Not Found in GCS bucket - {gcs_bucket_name}')
88-
return gcs_blob_filename, pages
84+
creds= Credentials(access_token)
85+
storage_client = storage.Client(project=gcs_project_id, credentials=creds)
86+
87+
bucket = storage_client.bucket(gcs_bucket_name)
88+
blob = bucket.blob(blob_name)
89+
if blob.exists():
90+
content = blob.download_as_bytes()
91+
pdf_file = io.BytesIO(content)
92+
pdf_reader = PdfReader(pdf_file)
93+
# Extract text from all pages
94+
text = ""
95+
for page in pdf_reader.pages:
96+
text += page.extract_text()
97+
pages = [Document(page_content = text)]
98+
else:
99+
raise LLMGraphBuilderException(f'File Not Found in GCS bucket - {gcs_bucket_name}')
100+
return gcs_blob_filename, pages
89101

90102
def upload_file_to_gcs(file_chunk, chunk_number, original_file_name, bucket_name, folder_name_sha1_hashed):
91103
try:

backend/src/make_relationships.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
logging.basicConfig(format='%(asctime)s - %(message)s',level='INFO')
1313

1414
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL')
15-
EMBEDDING_FUNCTION , EMBEDDING_DIMENSION = load_embedding_model(EMBEDDING_MODEL)
1615

1716
def merge_relationship_between_chunk_and_entites(graph: Neo4jGraph, graph_documents_chunk_chunk_Id : list):
1817
batch_data = []
@@ -41,7 +40,7 @@ def merge_relationship_between_chunk_and_entites(graph: Neo4jGraph, graph_docume
4140
def create_chunk_embeddings(graph, chunkId_chunkDoc_list, file_name):
4241
isEmbedding = os.getenv('IS_EMBEDDING')
4342

44-
embeddings, dimension = EMBEDDING_FUNCTION , EMBEDDING_DIMENSION
43+
embeddings, dimension = load_embedding_model(EMBEDDING_MODEL)
4544
logging.info(f'embedding model:{embeddings} and dimesion:{dimension}')
4645
data_for_query = []
4746
logging.info(f"update embedding and vector index for chunks")
@@ -161,6 +160,7 @@ def create_chunk_vector_index(graph):
161160
vector_index_query = "SHOW INDEXES YIELD name, type, labelsOrTypes, properties WHERE name = 'vector' AND type = 'VECTOR' AND 'Chunk' IN labelsOrTypes AND 'embedding' IN properties RETURN name"
162161
vector_index = execute_graph_query(graph,vector_index_query)
163162
if not vector_index:
163+
EMBEDDING_FUNCTION , EMBEDDING_DIMENSION = load_embedding_model(EMBEDDING_MODEL)
164164
vector_store = Neo4jVector(embedding=EMBEDDING_FUNCTION,
165165
graph=graph,
166166
node_label="Chunk",

backend/src/ragas_eval.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
from ragas.embeddings import LangchainEmbeddingsWrapper
1414
import nltk
1515

16-
nltk.download('punkt')
16+
nltk.data.path.append("/usr/local/nltk_data")
17+
nltk.data.path.append(os.path.expanduser("~/.nltk_data"))
18+
try:
19+
nltk.data.find("tokenizers/punkt")
20+
except LookupError:
21+
nltk.download("punkt", download_dir=os.path.expanduser("~/.nltk_data"))
22+
1723
load_dotenv()
1824

1925
EMBEDDING_MODEL = os.getenv("RAGAS_EMBEDDING_MODEL")

backend/src/shared/common_fn.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import hashlib
2+
import os
3+
from transformers import AutoTokenizer, AutoModel
4+
from langchain_huggingface import HuggingFaceEmbeddings
5+
from threading import Lock
26
import logging
37
from src.document_sources.youtube import create_youtube_url
4-
from langchain_huggingface import HuggingFaceEmbeddings
58
from langchain_google_vertexai import VertexAIEmbeddings
69
from langchain_openai import OpenAIEmbeddings
710
from langchain_neo4j import Neo4jGraph
@@ -16,6 +19,40 @@
1619
import boto3
1720
from langchain_community.embeddings import BedrockEmbeddings
1821

22+
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
23+
MODEL_PATH = "./local_model"
24+
_lock = Lock()
25+
_embedding_instance = None
26+
27+
def ensure_sentence_transformer_model_downloaded():
28+
if os.path.isdir(MODEL_PATH):
29+
print("Model already downloaded at:", MODEL_PATH)
30+
return
31+
else:
32+
print("Downloading model to:", MODEL_PATH)
33+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
34+
model = AutoModel.from_pretrained(MODEL_NAME)
35+
tokenizer.save_pretrained(MODEL_PATH)
36+
model.save_pretrained(MODEL_PATH)
37+
print("Model downloaded and saved.")
38+
39+
def get_local_sentence_transformer_embedding():
40+
"""
41+
Lazy, threadsafe singleton. Caller does not need to worry about
42+
import-time initialization or download race.
43+
"""
44+
global _embedding_instance
45+
if _embedding_instance is not None:
46+
return _embedding_instance
47+
with _lock:
48+
if _embedding_instance is not None:
49+
return _embedding_instance
50+
# Ensure model is present before instantiating
51+
ensure_sentence_transformer_model_downloaded()
52+
_embedding_instance = HuggingFaceEmbeddings(model_name=MODEL_PATH)
53+
print("Embedding model initialized.")
54+
return _embedding_instance
55+
1956
def check_url_source(source_type, yt_url:str=None, wiki_query:str=None):
2057
language=''
2158
try:
@@ -85,9 +122,8 @@ def load_embedding_model(embedding_model_name: str):
85122
dimension = 1536
86123
logging.info(f"Embedding: Using bedrock titan Embeddings , Dimension:{dimension}")
87124
else:
88-
embeddings = HuggingFaceEmbeddings(
89-
model_name="all-MiniLM-L6-v2"#, cache_folder="/embedding_model"
90-
)
125+
# embeddings = HuggingFaceEmbeddings(model_name="./local_model")
126+
embeddings = get_local_sentence_transformer_embedding()
91127
dimension = 384
92128
logging.info(f"Embedding: Using Langchain HuggingFaceEmbeddings , Dimension:{dimension}")
93129
return embeddings, dimension

docker-compose.yml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,9 @@ services:
77
dockerfile: Dockerfile
88
volumes:
99
- ./backend:/code
10+
env_file:
11+
- ./backend/.env
1012
environment:
11-
- NEO4J_URI=${NEO4J_URI-neo4j://database:7687}
12-
- NEO4J_PASSWORD=${NEO4J_PASSWORD-password}
13-
- NEO4J_USERNAME=${NEO4J_USERNAME-neo4j}
14-
- OPENAI_API_KEY=${OPENAI_API_KEY-}
15-
- DIFFBOT_API_KEY=${DIFFBOT_API_KEY-}
16-
- EMBEDDING_MODEL=${EMBEDDING_MODEL-all-MiniLM-L6-v2}
1713
- LANGCHAIN_ENDPOINT=${LANGCHAIN_ENDPOINT-}
1814
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-}
1915
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT-}

0 commit comments

Comments
 (0)