-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
238 lines (187 loc) · 8.47 KB
/
main.py
File metadata and controls
238 lines (187 loc) · 8.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import os
import logging
import time
import uuid
import json
from typing import Optional, Dict, Any, Union, Tuple, List
from flask import Flask, request, jsonify, g, Response
from openai import OpenAI
from dotenv import load_dotenv
from piggy_bank.db import init_db
from piggy_bank.tools import get_tools, run_tools
from piggy_bank.services import get_accounts
import sqlite3
# Pull variables from a local .env file into os.environ before anything below
# reads them (the OpenAI client reads OPEN_AI_KEY).
load_dotenv()

# SQLite file backing the app; get_db() opens it via app.config["DATABASE"].
DB_FILE: str = "pigbank.db"

app: Flask = Flask(__name__)
app.config.update(DATABASE=DB_FILE)

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

client = OpenAI(api_key=os.getenv("OPEN_AI_KEY"))

# Seconds of inactivity after which a chat session is considered expired.
SESSION_TIMEOUT_SECONDS: int = 60
def get_db() -> sqlite3.Connection:
    """Return the SQLite connection cached on ``flask.g``, opening it on first use.

    Repeated calls within the same application context share one connection;
    ``close_db`` pops and closes it when that context is torn down.
    """
    if "db" in g:
        return g.db  # type: ignore
    g.db = init_db(app.config["DATABASE"])
    return g.db  # type: ignore
@app.teardown_appcontext
def close_db(exception: Optional[BaseException]) -> None:
    """Close the connection ``get_db`` cached on ``g``, if one was opened.

    ``exception`` is the error that ended the context (or ``None``); it is not
    used because the connection is released either way.
    """
    connection = g.pop("db", None)
    if connection is None:
        return
    connection.close()
# ------------------ AUTH ------------------
def get_subscription_id_from_token(token: str) -> Optional[int]:
    """Resolve an auth token to its subscription id.

    Returns the ``id`` of the ``subscriptions`` row whose ``auth_token`` equals
    ``token``, or ``None`` when no such row exists.
    """
    match = get_db().execute(
        "SELECT id FROM subscriptions WHERE auth_token = ?", (token,)
    ).fetchone()
    if match is None:
        return None
    return match["id"]
@app.before_request
def check_auth() -> Optional[Tuple[Response, int]]:
    """Authenticate every request via an ``Authorization: Bearer <token>`` header.

    On success, stores the caller's subscription id in ``g.subscription_id``
    and returns ``None`` so Flask proceeds to the view. Otherwise it
    short-circuits the request with a JSON error body and HTTP 401.
    """
    auth_header: str = request.headers.get("Authorization", "")  # type: ignore
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    # HTTP authentication schemes are case-insensitive (RFC 9110 §11.1).
    # An empty token is rejected up front so it can never match a
    # subscription row whose auth_token happens to be an empty string.
    if scheme.lower() != "bearer" or not token:
        return jsonify({"error": "Authorization header missing."}), 401
    subscription_id = get_subscription_id_from_token(token)
    # Compare with None (not truthiness) so that a valid id of 0 is accepted.
    if subscription_id is None:
        return jsonify({"error": "Invalid subscription token."}), 401
    g.subscription_id = subscription_id
    return None
# ------------------ OPENAI HELPERS ------------------
def _touch_session(db: sqlite3.Connection, session_id: str, now: float) -> None:
    """Set ``sessions.last_access`` of ``session_id`` to ``now`` and commit."""
    db.execute("UPDATE sessions SET last_access = ? WHERE id = ?", (now, session_id))
    db.commit()


def _accounts_prompt_suffix(db: sqlite3.Connection, subscription_id: int) -> str:
    """Build the account-snapshot text appended to a new session's system prompt.

    Returns an empty string (and logs a warning) when ``get_accounts`` reports
    an error. A failed lookup must not be presented to the model as
    "No accounts currently exist".
    """
    accounts_result = get_accounts(db, subscription_id)
    if accounts_result["error"] is not None:
        log.warning(
            "Could not load accounts for subscription %s: %s",
            subscription_id,
            accounts_result["error"],
        )
        return ""
    accounts = accounts_result["response"]["accounts"]
    if not accounts:
        return "\n\nNo accounts currently exist."
    lines = [
        f"- {account['name']}: ${account['balance']:.2f} (ID: {account['id']})"
        for account in accounts
    ]
    return "\n\nCurrent accounts:\n" + "\n".join(lines)


def get_or_create_session(
    session_id: Optional[str], now: float, subscription_id: int
) -> Tuple[str, List[Dict[str, Any]]]:
    """Get existing session or create a new one, returning session_id and messages.

    Resolution order:
      1. ``session_id``, if given, owned by ``subscription_id`` and accessed
         less than ``SESSION_TIMEOUT_SECONDS`` ago. If it exists but has
         expired, it is deleted.
      2. Otherwise, the subscription's most recently accessed session, if it
         is still within the timeout.
      3. Otherwise, a new session seeded with a system prompt that includes a
         snapshot of the subscription's accounts.
    A resumed session has its ``last_access`` bumped to ``now``.

    Args:
        session_id: Session requested by the client, or ``None``.
        now: Current time in seconds (compared against ``last_access``).
        subscription_id: Authenticated subscription making the request.

    Returns:
        ``(session_id, messages)``, where ``messages`` is the decoded
        conversation history.
    """
    db = get_db()

    if session_id:
        log.info("Looking for existing session: %s", session_id)
        row = db.execute(
            "SELECT messages, last_access FROM sessions WHERE id = ? AND subscription_id = ?",
            (session_id, subscription_id),
        ).fetchone()
        if row and now - row["last_access"] < SESSION_TIMEOUT_SECONDS:
            _touch_session(db, session_id, now)
            log.info("Using existing session: %s", session_id)
            return session_id, json.loads(row["messages"])
        if row:
            # The requested session exists but is stale; drop it.
            db.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
            db.commit()
            log.info("Session %s expired, deleted", session_id)

    # No usable explicit session: fall back to the most recently used one.
    row = db.execute(
        "SELECT id, messages, last_access FROM sessions WHERE subscription_id = ? ORDER BY last_access DESC LIMIT 1",
        (subscription_id,),
    ).fetchone()
    if row and now - row["last_access"] < SESSION_TIMEOUT_SECONDS:
        _touch_session(db, row["id"], now)
        log.info("Resuming most recent session: %s", row["id"])
        return row["id"], json.loads(row["messages"])

    new_session_id = str(uuid.uuid4())
    log.info("Creating new session: %s", new_session_id)
    accounts_info = _accounts_prompt_suffix(db, subscription_id)
    messages = [
        {
            "role": "system",
            "content": f"""You are a helpful piggy bank assistant that immediately executes all requested operations.
You have access to a set of tools to manage accounts, and you can call multiple tools in parallel.
IMPORTANT: Do NOT ask for verification or user approval - execute all requested actions immediately.
When given multi-step commands (like "add money and transfer"), execute ALL steps in the same response.
Do NOT plan or explain what you will do - just do it immediately.
If money was added, transferred or withdrawn, you should always return the updated balance after all operations are complete.{accounts_info}""",
        }
    ]
    db.execute(
        "INSERT INTO sessions (id, subscription_id, messages, last_access) VALUES (?, ?, ?, ?)",
        (new_session_id, subscription_id, json.dumps(messages), now),
    )
    db.commit()
    return new_session_id, messages
def update_session_messages(session_id: str, messages: List[Dict[str, Any]]) -> None:
    """Store ``messages`` as the session's JSON history and stamp ``last_access``.

    The access time is the current wall-clock time, not a caller-supplied value.
    """
    connection = get_db()
    serialized = json.dumps(messages)
    touched_at = time.time()
    connection.execute(
        "UPDATE sessions SET messages = ?, last_access = ? WHERE id = ?",
        (serialized, touched_at, session_id),
    )
    connection.commit()
def process_openai_response(
    subscription_id: int,
    messages: List[Dict[str, Any]],
    *,
    model: str = "gpt-4-turbo",
    max_tool_rounds: int = 1,
) -> Any:
    """Process OpenAI response and handle tool calls if needed.

    ``messages`` is mutated in place: every assistant message and every tool
    result produced here is appended, so the caller can persist the history.

    Args:
        subscription_id: Subscription on whose behalf tools are executed.
        messages: Conversation so far, normally ending with the user's query.
        model: OpenAI chat model name.
        max_tool_rounds: How many consecutive rounds of tool calls the model
            may request. Once they are used up, one last completion is
            requested without tools. The default of 1 gives a single tool
            round followed by a tool-less follow-up.

    Returns:
        The final assistant message object returned by the OpenAI client.
    """
    db = get_db()
    tools = get_tools()
    rounds_left = max(max_tool_rounds, 0)
    while True:
        offer_tools = rounds_left > 0
        request_kwargs: Dict[str, Any] = {"model": model, "messages": messages}
        if offer_tools:
            request_kwargs["tools"] = tools
        response = client.chat.completions.create(**request_kwargs)
        response_message = response.choices[0].message
        log.debug("Received response: %s", response_message.content)
        # exclude_none keeps null fields (e.g. "tool_calls": null, "audio": null)
        # out of the stored history. That history is replayed to the API on later
        # turns, and the API may reject such nulls.
        messages.append(response_message.model_dump(exclude_none=True))

        tool_calls = response_message.tool_calls if offer_tools else None
        if not tool_calls:
            return response_message

        rounds_left -= 1
        log.info(
            "Tool calls requested: %s",
            [f"{tc.function.name}({tc.function.arguments})" for tc in tool_calls],
        )
        # The request's app context is already active. Pushing a nested
        # app_context() here would only hand run_tools a fresh, empty ``g``.
        messages.extend(run_tools(db, subscription_id, tool_calls))
def cleanup_expired_sessions() -> None:
    """Delete sessions whose ``last_access`` is older than SESSION_TIMEOUT_SECONDS.

    Logs how many rows were removed, but only when at least one was.
    """
    oldest_allowed = time.time() - SESSION_TIMEOUT_SECONDS
    connection = get_db()
    removed = connection.execute(
        "DELETE FROM sessions WHERE last_access < ?", (oldest_allowed,)
    ).rowcount
    connection.commit()
    if removed <= 0:
        return
    log.info("Cleaned up %s expired sessions", removed)
# ------------------ ROUTES ------------------
@app.route("/agent", methods=["POST"])
def agent() -> Union[Response, Tuple[Response, int]]:
    """Answer a natural-language banking query through the OpenAI agent.

    Expects a JSON object body with a non-empty string ``query`` and an
    optional ``session_id`` to continue a conversation.

    Returns:
        On success, ``{"response": ..., "session_id": ...}``.
        HTTP 400 when the query is missing or invalid.
        HTTP 500 with the error text if the agent pipeline raises.
    """
    # silent=True: a missing, malformed or non-JSON body yields None instead of
    # Flask's HTML 400/415 error page, so clients always get a JSON error.
    payload = request.get_json(silent=True)
    # Only a JSON object has fields to read; a list or scalar body would
    # otherwise fail on .get() with an unhandled AttributeError.
    data: Dict[str, Any] = payload if isinstance(payload, dict) else {}
    user_query = data.get("query")
    session_id: Optional[str] = data.get("session_id")
    subscription_id = g.get("subscription_id")

    if not isinstance(user_query, str) or not user_query.strip():
        return jsonify({"error": "Query is required"}), 400

    try:
        # Get or create session and retrieve conversation history
        session_id, messages = get_or_create_session(session_id, time.time(), subscription_id)
        messages.append({"role": "user", "content": user_query})
        response_message = process_openai_response(subscription_id, messages)
        # Persisted only after a successful round trip. If the OpenAI call
        # raises, the user turn is not written to the stored history.
        update_session_messages(session_id, messages)
        return jsonify({"response": response_message.content, "session_id": session_id})
    except Exception as e:
        # log.exception records the traceback, which log.error discarded.
        log.exception("Error in OpenAI integration: %s", e)
        return jsonify({"error": str(e)}), 500
# ------------------ MAIN ------------------
if __name__ == "__main__":
    # Purge sessions that went stale while the server was down. get_db() relies
    # on flask.g, which needs an application context; no request exists yet.
    with app.app_context():
        cleanup_expired_sessions()
    app.run(host="0.0.0.0", port=5000)