Skip to content

Commit 0fa0426

Browse files
committed
feat: added chat with model
1 parent 7d60a2e commit 0fa0426

File tree

5 files changed

+101
-44
lines changed

5 files changed

+101
-44
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
venv/
33
__pycache__/
44
.env
5+
chat.db

server/api/handlers/profile.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,7 @@
22
from extensions import neo4j_db
33
from models import User, Project, Tag
44
from sqlalchemy.orm import sessionmaker
5-
from dotenv import load_dotenv
65
import os
7-
import google.generativeai as genai
8-
9-
load_dotenv()
10-
11-
# Configure Genai Key
12-
genai.configure(api_key=os.getenv("GENAI_API_KEY"))
13-
14-
def get_gemini_response(question, prompt):
15-
model = genai.GenerativeModel('gemini-pro')
16-
response = model.generate_content([prompt, question])
17-
return response.text
18-
19-
def analyze_project_description(description):
20-
prompt = (
21-
"You are a sophisticated AI model trained to categorize project descriptions into predefined domains. "
22-
"Please review the following project description and determine the most relevant domain from the provided options. "
23-
"The available domains are: healthcare, fintech, blockchain, sports, agriculture. Based on the description, "
24-
"choose the most appropriate domain from the given. If multiple domains are applicable, provide "
25-
"comma-separated tags.\n\n"
26-
"Project Description: {description}\n\n"
27-
"Domain(s) (comma-separated): "
28-
)
29-
question = description
30-
response = get_gemini_response(question, prompt)
31-
tags = extract_tags(response)
32-
return tags
33-
34-
def extract_tags(response_text):
35-
tags_part = response_text.split("Tags:")[-1].strip()
36-
tags_list = tags_part.split(',')
37-
tags = [tag.strip().lower() for tag in tags_list if tag.strip()]
38-
return ', '.join(tags)
396

407
def get_profile(username):
418
logged_in_user = request.args.get('logged_in_user')
@@ -130,11 +97,6 @@ def add_project(username):
13097
if project_record:
13198
project = project_record["p"]
13299

133-
# Analyze the description to get tags
134-
# tags = analyze_project_description(description)
135-
136-
# Categorize the tags into predefined domains
137-
# predefined_domains = ['healthcare', 'fintech', 'blockchain', 'sports', 'agriculture']
138100
domain_tags = [tag for tag in tags.split(',') if tag.strip() in tags]
139101

140102
# Update project with tags
@@ -169,11 +131,7 @@ def update_project(username, project_id):
169131
title = data.get('title')
170132
description = data.get('description')
171133
repo_link = data.get('repo_link')
172-
173-
if description:
174-
tags = analyze_project_description(description)
175-
else:
176-
tags = None
134+
tags = data.get('tags', '')
177135

178136
query = """
179137
MATCH (u:User {username: $username})-[:OWNS]->(p:Project {id: $project_id})
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from flask import Flask, request, jsonify
2+
from dotenv import load_dotenv
3+
import os
4+
from langchain_google_genai import ChatGoogleGenerativeAI
5+
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
6+
from langchain_community.graphs import Neo4jGraph
7+
import sqlite3
8+
import uuid
9+
10+
# Load environment variables
11+
os.environ["GOOGLE_API_KEY"] = os.getenv("GOOGLE_API_KEY")
12+
os.environ["NEO4J_URI"] = os.getenv("NEO4J_URI")
13+
os.environ["NEO4J_USERNAME"] = os.getenv("NEO4J_USERNAME")
14+
os.environ["NEO4J_PASSWORD"] = os.getenv("NEO4J_PASSWORD")
15+
16+
# Initialize Neo4j graph and language model
17+
graph = Neo4jGraph()
18+
llm = ChatGoogleGenerativeAI(model="gemini-pro")
19+
chain = GraphCypherQAChain.from_llm(graph=graph, llm=llm, verbose=True)
20+
21+
# Initialize SQLite database connection
22+
def init_db():
23+
conn = sqlite3.connect('chat.db')
24+
cursor = conn.cursor()
25+
cursor.execute('''
26+
CREATE TABLE IF NOT EXISTS chats (
27+
chat_id TEXT PRIMARY KEY,
28+
query TEXT NOT NULL,
29+
response TEXT NOT NULL
30+
)
31+
''')
32+
conn.commit()
33+
conn.close()
34+
35+
init_db()
36+
37+
def save_chat(chat_id, query, response):
38+
"""Save chat message to SQLite database."""
39+
conn = sqlite3.connect('chat.db')
40+
cursor = conn.cursor()
41+
cursor.execute('''
42+
INSERT INTO chats (chat_id, query, response) VALUES (?, ?, ?)
43+
''', (chat_id, query, response))
44+
conn.commit()
45+
conn.close()
46+
47+
def get_chat_by_id(chat_id):
48+
"""Retrieve chat query and response by chat ID from SQLite database."""
49+
conn = sqlite3.connect('chat.db')
50+
cursor = conn.cursor()
51+
cursor.execute('''
52+
SELECT query, response FROM chats WHERE chat_id = ?
53+
''', (chat_id,))
54+
result = cursor.fetchone()
55+
conn.close()
56+
return {"query": result[0], "response": result[1]} if result else None
57+
58+
59+
def get_graph_response(query):
60+
"""Get response from graph database using natural language query."""
61+
response = chain.invoke({"query": query})
62+
return response['result']
63+
64+
def chat():
65+
"""Chat route to handle user queries and return responses."""
66+
data = request.json
67+
query = data.get("query")
68+
69+
if not query:
70+
return jsonify({"error": "No query provided"}), 400
71+
72+
try:
73+
result = get_graph_response(query)
74+
chat_id = str(uuid.uuid4()) # Generate unique chat ID
75+
save_chat(chat_id, query, result)
76+
return jsonify({"chat_id": chat_id, "result": result}), 200
77+
except Exception as e:
78+
return jsonify({"error": "An error occurred while processing the query"}), 500
79+
80+
def retrieve_chat(chat_id):
81+
"""Retrieve chat response by chat ID."""
82+
try:
83+
result = get_chat_by_id(chat_id)
84+
if result:
85+
return jsonify({"chat_id": chat_id, "result": result}), 200
86+
else:
87+
return jsonify({"error": "Chat ID not found"}), 404
88+
except Exception as e:
89+
return jsonify({"error": "An error occurred while retrieving the chat"}), 500
90+

server/api/urls.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from api.handlers.profile import get_profile, update_profile, add_project, update_project, delete_project
33
from api.handlers.analyze.githubdata import github_data, top_languages, streak_stats, pinned_repos, streak_chart
44
from api.handlers.analyze.leetcodedata import leetcode_data, leetcode_card
5+
from api.handlers.query.querymodel import chat,retrieve_chat
56
from api.handlers.friends import friends_bp
67

78
def register_routes(app):
@@ -35,6 +36,10 @@ def register_routes(app):
3536
app.add_url_rule('/profile/<username>/projects/<int:project_id>', 'update_project', update_project, methods=['PUT'])
3637
app.add_url_rule('/profile/<username>/projects/<string:project_title>', 'delete_project', delete_project, methods=['DELETE'])
3738

39+
# Chat with model routes
40+
app.add_url_rule('/chat','chat',chat, methods=['POST'])
41+
app.add_url_rule('/chat/<chat_id>','retrieve_chat',retrieve_chat, methods=['GET'])
42+
3843
# Landing page route
3944
app.add_url_rule('/', 'index', index)
4045

server/requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,7 @@ bcrypt
99
neo4j
1010
neo4j-driver
1111
google-generativeai
12-
beautifulsoup4
12+
beautifulsoup4
13+
langchain==0.2.1
14+
langchain-community==0.2.1
15+
langchain-google-genai==1.0.5

0 commit comments

Comments
 (0)