-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
96 lines (79 loc) · 3.01 KB
/
app.py
File metadata and controls
96 lines (79 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import ast
import os
import re

from dotenv import load_dotenv
from flask import Flask, request, jsonify
from langchain.prompts import PromptTemplate
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.prompt import SQL_PREFIX
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_openai import ChatOpenAI
# Load variables from a local .env file (if present) before any configuration
# is read from the environment below.
load_dotenv()

app = Flask(__name__)

# Database setup
_database_uri = os.getenv('SQL_DATABASE_URI')
if not _database_uri:
    # Fail fast with an actionable message instead of handing None/"" to
    # SQLDatabase.from_uri and surfacing an obscure engine error.
    raise RuntimeError(
        "SQL_DATABASE_URI is not set; define it in the environment or in a .env file."
    )
db = SQLDatabase.from_uri(_database_uri)
def query_as_list(db: "SQLDatabase", query: str) -> list:
    """Run *query* and return the distinct, non-empty values it yields as strings.

    ``db.run`` returns the result set rendered as a Python-literal string
    (e.g. ``"[('Apples',), ('Pears ',)]"``). That string is parsed back with
    ``ast.literal_eval``, flattened across all columns, stripped and
    de-duplicated.

    Args:
        db: Database handle whose ``run`` method executes *query*.
        query: SQL statement to execute. It is run verbatim, so it should be a
            trusted, hard-coded query.

    Returns:
        Unique, stripped string values in arbitrary order. Returns ``[]`` when
        the query yields no rows.
    """
    raw = str(db.run(query))
    # SQLDatabase.run returns "" (not "[]") for an empty result set, and
    # ast.literal_eval("") raises, so handle that case explicitly.
    if not raw.strip():
        return []
    rows = ast.literal_eval(raw)
    # Falsy cells (None, "", 0) are skipped as before. Other cells are coerced
    # to str so numeric columns don't break .strip().
    values = {str(cell).strip() for row in rows for cell in row if cell}
    # Drop values that were whitespace-only and became empty after stripping.
    values.discard("")
    return list(values)
# Initialize LLM and agents
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0,
)

# Instructions prepended to the SQL agent's system prompt.
# NOTE(review): this is only a prompt instruction, not an access control. The
# model can still be induced to emit DML, so enforce read-only access through
# the database credentials used in SQL_DATABASE_URI.
system_prompt = "You are a helpful SQL assistant. Answer questions about the database clearly and concisely. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."

# create_sql_agent does not define a `system_message` parameter. Unknown
# keyword arguments are forwarded to the agent object rather than used to
# build the prompt, so instructions passed that way do not reach the model.
# The documented hook is `prefix`, which becomes the system message for the
# "openai-tools" agent type. It is combined here with the library's default
# SQL_PREFIX so that prompt's tool-usage guidance is kept.
# create_sql_agent calls str.format() on the prefix to fill SQL_PREFIX's
# {dialect}/{top_k} placeholders, so system_prompt must not contain literal
# braces.
agent_executor = create_sql_agent(
    llm,
    db=db,
    agent_type="openai-tools",
    verbose=True,  # prints the agent's intermediate steps (generated SQL, results)
    prefix=f"{system_prompt}\n\n{SQL_PREFIX}",
)
# Fallback conversational chain, used for questions that are not routed to the
# SQL agent. The prompt template is piped into the same LLM instance.
general_template = """You are a helpful AI assistant. Please respond to the user's question:
Question: {question}
Answer: """
# from_template infers the template's single input variable ("question").
general_prompt = PromptTemplate.from_template(general_template)
general_chain = general_prompt | llm
# Snapshot product and category names once, at import time. The question
# router uses them as extra keywords.
# NOTE(review): these lists are not refreshed while the process is running.
products, categories = (
    query_as_list(db, sql)
    for sql in (
        "SELECT product_name FROM inventory_products",
        "SELECT name FROM categories",
    )
)
# Generic words/phrases suggesting a question is about the database. Product and
# category names loaded from the database are added at call time.
_DB_KEYWORDS = (
    'database', 'table', 'sql', 'query', 'inventory', 'inventory_products',
    'record', 'data', 'select', 'show', 'product', 'category', 'products',
    'categories', 'list', 'count', 'how many', 'find', 'search', 'where', 'report',
)


def _normalize_phrase(text):
    """Return *text* lower-cased as space-joined word tokens, padded with spaces.

    Punctuation is dropped, so "Products?" and "products" normalize alike. The
    surrounding spaces let a plain substring test match whole words or whole
    multi-word phrases only (" where " never matches inside " there ").
    Returns "" when *text* contains no word characters.
    """
    words = re.findall(r"\w+", str(text).lower())
    return f" {' '.join(words)} " if words else ""


def is_database_question(question, extra_keywords=None):
    """Heuristically decide whether *question* should go to the SQL agent.

    The question matches when any keyword or phrase appears in it as whole
    word(s), ignoring case and punctuation. Keywords are the generic
    ``_DB_KEYWORDS`` plus *extra_keywords*.

    Args:
        question: The user's question text.
        extra_keywords: Additional keywords/phrases to match. Defaults to the
            module-level ``products`` and ``categories`` lists.

    Returns:
        True if any keyword or phrase occurs in the question, else False.
    """
    if extra_keywords is None:
        extra_keywords = [*products, *categories]
    haystack = _normalize_phrase(question)
    if not haystack:
        return False
    # Keywords that normalize to "" (pure punctuation) are ignored.
    return any(
        needle and needle in haystack
        for needle in map(_normalize_phrase, (*_DB_KEYWORDS, *extra_keywords))
    )
@app.route('/ask', methods=['POST'])
def ask_question():
    """Answer a question posted as the JSON body ``{"question": "..."}``.

    Questions that ``is_database_question`` flags go to the SQL agent. All
    other questions go to the general chat chain.

    Returns:
        200 with ``{"type": "database" | "general", "response": <text>}``.
        400 with ``{"error": ...}`` when the body is not a JSON object holding
        a non-empty string ``question``.
        500 with a generic ``{"error": ...}`` when answering fails. The details
        are logged server-side and are not returned to the client.
    """
    # silent=True makes malformed JSON or a non-JSON content type yield None
    # (a 400 below) instead of raising an HTTP error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'question' not in data:
        return jsonify({'error': 'No question provided'}), 400

    user_question = data['question']
    if not isinstance(user_question, str):
        return jsonify({'error': "'question' must be a string"}), 400
    if not user_question.strip():
        return jsonify({'error': 'No question provided'}), 400

    try:
        if is_database_question(user_question):
            response = agent_executor.invoke({"input": user_question})
            return jsonify({
                'type': 'database',
                'response': response["output"]
            })
        response = general_chain.invoke({"question": user_question})
        return jsonify({
            'type': 'general',
            'response': response.content
        })
    except Exception:
        # Top-level request boundary: record the full traceback, but don't echo
        # exception text (which may include SQL, schema details or upstream API
        # errors) back to the caller.
        app.logger.exception("Failed to answer question")
        return jsonify({'error': 'Internal server error'}), 500
if __name__ == '__main__':
    # Imported here so that importing this module (e.g. from another WSGI
    # host) does not require waitress.
    from waitress import serve

    # HOST/PORT can be overridden via the environment. The defaults match the
    # original values (all interfaces, port 5000).
    serve(app, host=os.getenv('HOST', '0.0.0.0'), port=int(os.getenv('PORT', '5000')))