Skip to content

Commit 0573ce2

Browse files
2 parents 6171b59 + 2b230e3 commit 0573ce2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+597
-274
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ Allow unauthenticated request : Yes
149149
| VITE_GOOGLE_CLIENT_ID | Optional | | Client ID for Google authentication |
150150
| VITE_LLM_MODELS_PROD | Optional | openai_gpt_4o,openai_gpt_4o_mini,diffbot,gemini_1.5_flash | To distinguish models based on the environment (PROD or DEV) |
151151
| VITE_LLM_MODELS | Optional | 'diffbot,openai_gpt_3.5,openai_gpt_4o,openai_gpt_4o_mini,gemini_1.5_pro,gemini_1.5_flash,azure_ai_gpt_35,azure_ai_gpt_4o,ollama_llama3,groq_llama3_70b,anthropic_claude_3_5_sonnet' | Supported Models For the application
152+
| VITE_AUTH0_CLIENT_ID | Mandatory if you are enabling Authentication otherwise it is optional | | Okta OAuth Client ID for authentication |
153+
| VITE_AUTH0_DOMAIN | Mandatory if you are enabling Authentication otherwise it is optional | | Okta OAuth Client Domain |
154+
| VITE_SKIP_AUTH | Optional | true | Flag to skip the authentication
152155

153156
## LLMs Supported
154157
1. OpenAI

backend/example.env

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,8 @@ LLM_MODEL_CONFIG_ollama_llama3="model_name,model_local_url"
4444
YOUTUBE_TRANSCRIPT_PROXY="https://user:pass@domain:port"
4545
EFFECTIVE_SEARCH_RATIO=5
4646
GRAPH_CLEANUP_MODEL="openai_gpt_4o"
47-
CHUNKS_TO_BE_PROCESSED="50"
47+
CHUNKS_TO_BE_CREATED="50"
48+
BEDROCK_EMBEDDING_MODEL="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.titan-embed-text-v1"
49+
LLM_MODEL_CONFIG_bedrock_nova_micro_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-micro-v1:0"
50+
LLM_MODEL_CONFIG_bedrock_nova_lite_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-lite-v1:0"
51+
LLM_MODEL_CONFIG_bedrock_nova_pro_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-pro-v1:0"

backend/score.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from fastapi import FastAPI, File, UploadFile, Form, Request
1+
from fastapi import FastAPI, File, UploadFile, Form, Request, HTTPException
22
from fastapi_health import health
33
from fastapi.middleware.cors import CORSMiddleware
44
from src.main import *
@@ -19,7 +19,6 @@
1919
from src.neighbours import get_neighbour_nodes
2020
import json
2121
from typing import List
22-
from starlette.middleware.sessions import SessionMiddleware
2322
from google.oauth2.credentials import Credentials
2423
import os
2524
from src.logger import CustomLogger
@@ -33,6 +32,10 @@
3332
from starlette.types import ASGIApp, Receive, Scope, Send
3433
from langchain_neo4j import Neo4jGraph
3534
from src.entities.source_node import sourceNode
35+
from starlette.middleware.sessions import SessionMiddleware
36+
from starlette.responses import HTMLResponse, RedirectResponse,JSONResponse
37+
from starlette.requests import Request
38+
import secrets
3639

3740
logger = CustomLogger()
3841
CHUNK_DIR = os.path.join(os.path.dirname(__file__), "chunks")
@@ -77,6 +80,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
7780
)
7881
await gzip_middleware(scope, receive, send)
7982
app = FastAPI()
83+
8084
app.add_middleware(XContentTypeOptions)
8185
app.add_middleware(XFrame, Option={'X-Frame-Options': 'DENY'})
8286
app.add_middleware(CustomGZipMiddleware, minimum_size=1000, compresslevel=5,paths=["/sources_list","/url/scan","/extract","/chat_bot","/chunk_entities","/get_neighbours","/graph_query","/schema","/populate_graph_schema","/get_unconnected_nodes_list","/get_duplicate_nodes","/fetch_chunktext"])
@@ -86,14 +90,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
8690
allow_methods=["*"],
8791
allow_headers=["*"],
8892
)
93+
app.add_middleware(SessionMiddleware, secret_key=os.urandom(24))
8994

9095
is_gemini_enabled = os.environ.get("GEMINI_ENABLED", "False").lower() in ("true", "1", "yes")
9196
if is_gemini_enabled:
9297
add_routes(app,ChatVertexAI(), path="/vertexai")
9398

9499
app.add_api_route("/health", health([healthy_condition, healthy]))
95100

96-
app.add_middleware(SessionMiddleware, secret_key=os.urandom(24))
97101

98102

99103
@app.post("/url/scan")
@@ -346,14 +350,15 @@ async def post_processing(uri=Form(), userName=Form(), password=Form(), database
346350
await asyncio.to_thread(create_communities, uri, userName, password, database)
347351

348352
logging.info(f'created communities')
349-
graph = create_graph_database_connection(uri, userName, password, database)
350-
graphDb_data_Access = graphDBdataAccess(graph)
351-
document_name = ""
352-
count_response = graphDb_data_Access.update_node_relationship_count(document_name)
353-
if count_response:
354-
count_response = [{"filename": filename, **counts} for filename, counts in count_response.items()]
355-
logging.info(f'Updated source node with community related counts')
356-
353+
354+
355+
graph = create_graph_database_connection(uri, userName, password, database)
356+
graphDb_data_Access = graphDBdataAccess(graph)
357+
document_name = ""
358+
count_response = graphDb_data_Access.update_node_relationship_count(document_name)
359+
if count_response:
360+
count_response = [{"filename": filename, **counts} for filename, counts in count_response.items()]
361+
logging.info(f'Updated source node with community related counts')
357362

358363
end = time.time()
359364
elapsed_time = end - start
@@ -502,12 +507,14 @@ async def connect(uri=Form(), userName=Form(), password=Form(), database=Form())
502507
graph = create_graph_database_connection(uri, userName, password, database)
503508
result = await asyncio.to_thread(connection_check_and_get_vector_dimensions, graph, database)
504509
gcs_file_cache = os.environ.get('GCS_FILE_CACHE')
510+
chunk_to_be_created = int(os.environ.get('CHUNKS_TO_BE_CREATED', '50'))
505511
end = time.time()
506512
elapsed_time = end - start
507513
json_obj = {'api_name':'connect','db_url':uri, 'userName':userName, 'database':database, 'count':1, 'logging_time': formatted_time(datetime.now(timezone.utc)), 'elapsed_api_time':f'{elapsed_time:.2f}'}
508514
logger.log_struct(json_obj, "INFO")
509515
result['elapsed_api_time'] = f'{elapsed_time:.2f}'
510516
result['gcs_file_cache'] = gcs_file_cache
517+
result['chunk_to_be_created']= chunk_to_be_created
511518
return create_api_response('Success',data=result)
512519
except Exception as e:
513520
job_status = "Failed"
@@ -980,8 +987,8 @@ async def backend_connection_configuration():
980987
database= os.getenv('NEO4J_DATABASE')
981988
password= os.getenv('NEO4J_PASSWORD')
982989
gcs_file_cache = os.environ.get('GCS_FILE_CACHE')
990+
chunk_to_be_created = int(os.environ.get('CHUNKS_TO_BE_CREATED', '50'))
983991
if all([uri, username, database, password]):
984-
print(f'uri:{uri}, usrName:{username}, database :{database}, password: {password}')
985992
graph = Neo4jGraph()
986993
logging.info(f'login connection status of object: {graph}')
987994
if graph is not None:
@@ -995,6 +1002,7 @@ async def backend_connection_configuration():
9951002
result["database"] = database
9961003
result["password"] = encoded_password
9971004
result['gcs_file_cache'] = gcs_file_cache
1005+
result['chunk_to_be_created']= chunk_to_be_created
9981006
end = time.time()
9991007
elapsed_time = end - start
10001008
result['api_name'] = 'backend_connection_configuration'

backend/src/create_chunks.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from src.document_sources.youtube import get_chunks_with_timestamps, get_calculated_timestamps
66
import re
7+
import os
78

89
logging.basicConfig(format="%(asctime)s - %(message)s", level="INFO")
910

@@ -25,23 +26,28 @@ def split_file_into_chunks(self):
2526
"""
2627
logging.info("Split file into smaller chunks")
2728
text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
29+
chunk_to_be_created = int(os.environ.get('CHUNKS_TO_BE_CREATED', '50'))
2830
if 'page' in self.pages[0].metadata:
2931
chunks = []
3032
for i, document in enumerate(self.pages):
3133
page_number = i + 1
32-
for chunk in text_splitter.split_documents([document]):
33-
chunks.append(Document(page_content=chunk.page_content, metadata={'page_number':page_number}))
34+
if len(chunks) >= chunk_to_be_created:
35+
break
36+
else:
37+
for chunk in text_splitter.split_documents([document]):
38+
chunks.append(Document(page_content=chunk.page_content, metadata={'page_number':page_number}))
3439

3540
elif 'length' in self.pages[0].metadata:
3641
if len(self.pages) == 1 or (len(self.pages) > 1 and self.pages[1].page_content.strip() == ''):
3742
match = re.search(r'(?:v=)([0-9A-Za-z_-]{11})\s*',self.pages[0].metadata['source'])
3843
youtube_id=match.group(1)
3944
chunks_without_time_range = text_splitter.split_documents([self.pages[0]])
40-
chunks = get_calculated_timestamps(chunks_without_time_range, youtube_id)
41-
45+
chunks = get_calculated_timestamps(chunks_without_time_range[:chunk_to_be_created], youtube_id)
4246
else:
43-
chunks_without_time_range = text_splitter.split_documents(self.pages)
44-
chunks = get_chunks_with_timestamps(chunks_without_time_range)
47+
chunks_without_time_range = text_splitter.split_documents(self.pages)
48+
chunks = get_chunks_with_timestamps(chunks_without_time_range[:chunk_to_be_created])
4549
else:
4650
chunks = text_splitter.split_documents(self.pages)
51+
52+
chunks = chunks[:chunk_to_be_created]
4753
return chunks

backend/src/graphDB_dataAccess.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,4 +535,30 @@ def update_node_relationship_count(self,document_name):
535535
"nodeCount" : nodeCount,
536536
"relationshipCount" : relationshipCount
537537
}
538-
return response
538+
return response
539+
540+
def get_nodelabels_relationships(self):
    """Return the node labels and relationship types present in the graph.

    Two Cypher queries are run through ``self.execute_query``:

    * the first lists every label from ``db.labels()`` that is not in the
      exclusion list written in the query and that has at least one node;
    * the second lists every type from ``db.relationshipTypes()`` that is
      not in the exclusion list written in the query.

    Returns:
        tuple[list, list]: ``(node_labels, relationship_types)``. If either
        query fails, the error is logged and ``([], [])`` is returned so
        that callers unpacking two values keep working.
    """
    # Function-scope import keeps this block self-contained; the rest of the
    # module is not visible from here.
    import logging

    node_query = """
        CALL db.labels() YIELD label
        WITH label
        WHERE NOT label IN ['Document', 'Chunk', '_Bloom_Perspective_', '__Community__', '__Entity__']
        CALL apoc.cypher.run("MATCH (n:`" + label + "`) RETURN count(n) AS count",{}) YIELD value
        WHERE value.count > 0
        RETURN label order by label
        """

    relation_query = """
        CALL db.relationshipTypes() yield relationshipType
        WHERE NOT relationshipType IN ['PART_OF', 'NEXT_CHUNK', 'HAS_ENTITY', '_Bloom_Perspective_','FIRST_CHUNK','SIMILAR','IN_COMMUNITY','PARENT_COMMUNITY']
        return relationshipType order by relationshipType
        """

    try:
        node_result = self.execute_query(node_query)
        node_labels = [record["label"] for record in node_result]
        relationship_result = self.execute_query(relation_query)
        relationship_types = [record["relationshipType"] for record in relationship_result]
        return node_labels, relationship_types
    except Exception as e:
        # Best-effort lookup: report the failure but hand back an empty pair.
        # The pair shape matters: a bare [] would raise ValueError in callers
        # that do `labels, types = ...`.
        logging.exception(f"Error in getting node labels/relationship types from db: {e}")
        return [], []

backend/src/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def get_llm(model: str):
8989
)
9090

9191
llm = ChatBedrock(
92-
client=bedrock_client, model_id=model_name, model_kwargs=dict(temperature=0)
92+
client=bedrock_client,region_name=region_name, model_id=model_name, model_kwargs=dict(temperature=0)
9393
)
9494

9595
elif "ollama" in model:

backend/src/main.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,14 +361,12 @@ async def processing_source(uri, userName, password, database, model, file_name,
361361

362362
logging.info('Update the status as Processing')
363363
update_graph_chunk_processed = int(os.environ.get('UPDATE_GRAPH_CHUNKS_PROCESSED'))
364-
chunk_to_be_processed = int(os.environ.get('CHUNKS_TO_BE_PROCESSED', '50'))
364+
365365
# selected_chunks = []
366366
is_cancelled_status = False
367367
job_status = "Completed"
368368
for i in range(0, len(chunkId_chunkDoc_list), update_graph_chunk_processed):
369369
select_chunks_upto = i+update_graph_chunk_processed
370-
if select_chunks_upto > chunk_to_be_processed:
371-
break
372370
logging.info(f'Selected Chunks upto: {select_chunks_upto}')
373371
if len(chunkId_chunkDoc_list) <= select_chunks_upto:
374372
select_chunks_upto = len(chunkId_chunkDoc_list)
@@ -676,7 +674,7 @@ def get_labels_and_relationtypes(graph):
676674
query = """
677675
RETURN collect {
678676
CALL db.labels() yield label
679-
WHERE NOT label IN ['Chunk','_Bloom_Perspective_', '__Community__', '__Entity__']
677+
WHERE NOT label IN ['Document','Chunk','_Bloom_Perspective_', '__Community__', '__Entity__']
680678
return label order by label limit 100 } as labels,
681679
collect {
682680
CALL db.relationshipTypes() yield relationshipType as type

backend/src/post_processing.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from langchain_core.prompts import ChatPromptTemplate
99
from src.shared.constants import GRAPH_CLEANUP_PROMPT
1010
from src.llm import get_llm
11-
from src.main import get_labels_and_relationtypes
11+
from src.graphDB_dataAccess import graphDBdataAccess
12+
import time
13+
1214

1315
DROP_INDEX_QUERY = "DROP INDEX entities IF EXISTS;"
1416
LABELS_QUERY = "CALL db.labels()"
@@ -195,50 +197,35 @@ def update_embeddings(rows, graph):
195197
return graph.query(query,params={'rows':rows})
196198

197199
def graph_schema_consolidation(graph):
198-
nodes_and_relations = get_labels_and_relationtypes(graph)
199-
logging.info(f"nodes_and_relations in existing graph : {nodes_and_relations}")
200-
node_labels = []
201-
relation_labels = []
202-
203-
node_labels.extend(nodes_and_relations[0]['labels'])
204-
relation_labels.extend(nodes_and_relations[0]['relationshipTypes'])
205-
200+
graphDb_data_Access = graphDBdataAccess(graph)
201+
node_labels,relation_labels = graphDb_data_Access.get_nodelabels_relationships()
206202
parser = JsonOutputParser()
207-
prompt = ChatPromptTemplate(messages=[("system",GRAPH_CLEANUP_PROMPT),("human", "{input}")],
208-
partial_variables={"format_instructions": parser.get_format_instructions()})
209-
210-
graph_cleanup_model = os.getenv("GRAPH_CLEANUP_MODEL",'openai_gpt_4o')
203+
prompt = ChatPromptTemplate(
204+
messages=[("system", GRAPH_CLEANUP_PROMPT), ("human", "{input}")],
205+
partial_variables={"format_instructions": parser.get_format_instructions()}
206+
)
207+
graph_cleanup_model = os.getenv("GRAPH_CLEANUP_MODEL", 'openai_gpt_4o')
211208
llm, _ = get_llm(graph_cleanup_model)
212209
chain = prompt | llm | parser
213-
nodes_dict = chain.invoke({'input':node_labels})
214-
relation_dict = chain.invoke({'input':relation_labels})
215-
216-
node_match = {}
217-
relation_match = {}
218-
for new_label , values in nodes_dict.items() :
219-
for old_label in values:
220-
if new_label != old_label:
221-
node_match[old_label]=new_label
222-
223-
for new_label , values in relation_dict.items() :
224-
for old_label in values:
225-
if new_label != old_label:
226-
relation_match[old_label]=new_label
227-
228-
logging.info(f"updated node labels : {node_match}")
229-
logging.info(f"updated relationship labels : {relation_match}")
230210

231-
# Update node labels in graph
232-
for old_label, new_label in node_match.items():
233-
query = f"""
234-
MATCH (n:`{old_label}`)
235-
SET n:`{new_label}`
236-
REMOVE n:`{old_label}`
237-
"""
238-
graph.query(query)
211+
nodes_relations_input = {'nodes': node_labels, 'relationships': relation_labels}
212+
mappings = chain.invoke({'input': nodes_relations_input})
213+
node_mapping = {old: new for new, old_list in mappings['nodes'].items() for old in old_list if new != old}
214+
relation_mapping = {old: new for new, old_list in mappings['relationships'].items() for old in old_list if new != old}
215+
216+
logging.info(f"Node Labels: Total = {len(node_labels)}, Reduced to = {len(set(node_mapping.values()))} (from {len(node_mapping)})")
217+
logging.info(f"Relationship Types: Total = {len(relation_labels)}, Reduced to = {len(set(relation_mapping.values()))} (from {len(relation_mapping)})")
218+
219+
if node_mapping:
220+
for old_label, new_label in node_mapping.items():
221+
query = f"""
222+
MATCH (n:`{old_label}`)
223+
SET n:`{new_label}`
224+
REMOVE n:`{old_label}`
225+
"""
226+
graph.query(query)
239227

240-
# Update relation types in graph
241-
for old_label, new_label in relation_match.items():
228+
for old_label, new_label in relation_mapping.items():
242229
query = f"""
243230
MATCH (n)-[r:`{old_label}`]->(m)
244231
CREATE (n)-[r2:`{new_label}`]->(m)

backend/src/shared/common_fn.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import os
1212
from pathlib import Path
1313
from urllib.parse import urlparse
14-
14+
import boto3
15+
from langchain_community.embeddings import BedrockEmbeddings
1516

1617
def check_url_source(source_type, yt_url:str=None, wiki_query:str=None):
1718
language=''
@@ -77,6 +78,10 @@ def load_embedding_model(embedding_model_name: str):
7778
)
7879
dimension = 768
7980
logging.info(f"Embedding: Using Vertex AI Embeddings , Dimension:{dimension}")
81+
elif embedding_model_name == "titan":
82+
embeddings = get_bedrock_embeddings()
83+
dimension = 1536
84+
logging.info(f"Embedding: Using bedrock titan Embeddings , Dimension:{dimension}")
8085
else:
8186
embeddings = HuggingFaceEmbeddings(
8287
model_name="all-MiniLM-L6-v2"#, cache_folder="/embedding_model"
@@ -134,4 +139,38 @@ def last_url_segment(url):
134139
parsed_url = urlparse(url)
135140
path = parsed_url.path.strip("/") # Remove leading and trailing slashes
136141
last_url_segment = path.split("/")[-1] if path else parsed_url.netloc.split(".")[0]
137-
return last_url_segment
142+
return last_url_segment
143+
144+
def get_bedrock_embeddings():
    """Build a BedrockEmbeddings object from the BEDROCK_EMBEDDING_MODEL env var.

    The variable must contain exactly four comma-separated fields in the
    order ``model_name,aws_access_key,aws_secret_key,region_name``.
    Whitespace around each field is stripped before use.

    Returns:
        BedrockEmbeddings: Embeddings object backed by a ``bedrock-runtime``
        boto3 client created with the configured credentials and region.

    Raises:
        ValueError: If the variable is unset/empty or does not contain
            exactly four comma-separated fields.
        Exception: Any error raised while creating the boto3 client or the
            BedrockEmbeddings object is logged and re-raised unchanged.
    """
    env_value = os.getenv("BEDROCK_EMBEDDING_MODEL")
    if not env_value:
        raise ValueError("Environment variable 'BEDROCK_EMBEDDING_MODEL' is not set.")

    # Validate the shape up front instead of relying on an unpacking error,
    # so the configuration error is raised without a misleading exception
    # context attached.
    fields = [field.strip() for field in env_value.split(",")]
    if len(fields) != 4:
        raise ValueError(
            "Environment variable 'BEDROCK_EMBEDDING_MODEL' is improperly formatted. "
            "Expected format: 'model_name,aws_access_key,aws_secret_key,region_name'."
        )
    model_name, aws_access_key, aws_secret_key, region_name = fields

    try:
        bedrock_client = boto3.client(
            service_name="bedrock-runtime",
            region_name=region_name,
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_key,
        )
        return BedrockEmbeddings(
            model_id=model_name,
            client=bedrock_client,
        )
    except Exception:
        # Only client/embedding construction failures are genuinely
        # unexpected here; log with traceback and let the caller decide.
        logging.exception("An unexpected error occurred while creating Bedrock embeddings")
        raise

0 commit comments

Comments
 (0)