-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
211 lines (173 loc) · 7.57 KB
/
main.py
File metadata and controls
211 lines (173 loc) · 7.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import json
import logging
import os
import re
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sqlalchemy import text
from app.qif_indexer import QIFIndexer
# Configure logging. Use a safe lookup so an unknown or lowercase LOG_LEVEL
# value (e.g. "debug" or a typo) falls back to INFO instead of raising
# AttributeError at import time, which would prevent the app from starting.
default_level = os.getenv("LOG_LEVEL", "INFO")
logging.basicConfig(level=getattr(logging, default_level.upper(), logging.INFO))
logger = logging.getLogger(__name__)

# Environment-driven configuration: QIF input directory, SQLite path,
# and the Ollama endpoint/model used for NL-to-SQL generation.
qif_dir = os.getenv("QIF_DIR", "/qifs")
db_path = os.getenv("DB_PATH", "/db/transactions.db")
ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
ollama_model = os.getenv("OLLAMA_MODEL", "phi4-mini:3.8b")

# Create FastAPI app
app = FastAPI()

# Build/refresh the transactions database at startup so endpoints can query it.
indexer = QIFIndexer(qif_dir, db_path)
indexer.ensure_database()
logger.info("Database ready at %s", db_path)
def format_markdown_table(rows):
    """Render a list of row dicts as a GitHub-style markdown table.

    Column order follows the key order of the first row; every row is
    expected to carry the same keys. Returns a placeholder message when
    there are no rows.
    """
    if not rows:
        return "No results found."
    columns = list(rows[0].keys())
    lines = [
        "| " + " | ".join(columns) + " |",
        "| " + " | ".join("---" for _ in columns) + " |",
    ]
    for record in rows:
        cells = [str(record[col]) for col in columns]
        lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(lines)
def format_human_readable(rows):
    """Summarize query rows for the user.

    A single row containing a single column becomes an English sentence
    ("The <column> is <value>."); anything else falls back to a markdown
    table. Empty input yields a placeholder message.
    """
    if not rows:
        return "No results found."
    is_single_scalar = len(rows) == 1 and len(rows[0]) == 1
    if is_single_scalar:
        (key, value), = rows[0].items()
        label = key.replace('_', ' ')
        return f"The {label} is {value}."
    return format_markdown_table(rows)
def sanitize_llm_sql(raw_sql: str) -> str:
    """Validate LLM-generated SQL, allowing only a single SELECT statement.

    Strips markdown code fences, then rejects multi-statement input and
    any statement containing write/DDL keywords.

    Raises:
        HTTPException: 500 when the SQL is empty, is not a SELECT,
            contains multiple statements, or uses a blocked keyword.
    """
    if not raw_sql:
        raise HTTPException(status_code=500, detail="No SQL was generated by the LLM.")
    # Remove code fences BEFORE the semicolon check. The original checked the
    # raw text first, so a fenced single statement such as
    # "```sql\nSELECT 1;\n```" (semicolon followed by backticks, which
    # rstrip(";") cannot remove) was wrongly rejected as multiple statements.
    sql = re.sub(r"```sql\s*", "", raw_sql, flags=re.IGNORECASE)
    sql = re.sub(r"```", "", sql).strip().rstrip(";")
    if not sql:
        raise HTTPException(status_code=500, detail="No SQL was generated by the LLM.")
    # Any semicolon remaining after trimming the trailing one means a second statement.
    if ";" in sql:
        raise HTTPException(status_code=500, detail="Multiple SQL statements are not allowed.")
    lowered = sql.lower()
    if not lowered.startswith("select"):
        raise HTTPException(status_code=500, detail=f"Only SELECT is allowed. SQL: {sql}")
    disallowed = ["insert", "update", "delete", "drop", "alter", "pragma", "attach", "vacuum"]
    if any(re.search(rf"\b{keyword}\b", lowered) for keyword in disallowed):
        raise HTTPException(status_code=500, detail="Generated SQL contains blocked keywords.")
    return sql
def generate_sql(question: str) -> str:
    """Ask the Ollama model to translate a natural-language question into SQL.

    Streams the model's NDJSON response, concatenates the generated tokens,
    and returns the sanitized SELECT statement.

    Raises:
        HTTPException: 500 on a non-200 Ollama response or invalid SQL,
            503 when Ollama cannot be reached.
    """
    schema = "transactions(date DATE, payee TEXT, category TEXT, memo TEXT, amount REAL)"
    prompt = (
        "You are a SQLite SQL expert. Only return a valid SQLite SELECT statement for the question below, "
        f"using this schema:\n{schema}\n"
        "Use table name 'transactions'. Use strftime('%Y', date) for year filtering. "
        "Never use markdown or code fences. Never return non-SQL text.\n"
        f"Question: {question}\nSQL:"
    )
    try:
        with requests.post(
            f"{ollama_url}/api/generate",
            json={"model": ollama_model, "prompt": prompt},
            stream=True,
            timeout=(5, 60),
        ) as response:
            if response.status_code != 200:
                raise HTTPException(status_code=500, detail=f"Ollama error: {response.text}")
            # Collect the streamed "response" fragments; skip keep-alive blanks
            # and tolerate the odd malformed line without aborting the stream.
            pieces = []
            for chunk in response.iter_lines():
                if not chunk:
                    continue
                decoded = chunk.decode("utf-8")
                logger.debug("Raw line from LLM: %s", decoded)
                try:
                    pieces.append(json.loads(decoded).get("response", ""))
                except Exception as err:
                    logger.warning("Failed to parse JSON: %s | Line: %s", err, decoded)
            raw_sql = "".join(pieces)
            logger.info("Raw SQL from LLM before cleanup: %s", raw_sql)
            return sanitize_llm_sql(raw_sql)
    except requests.RequestException as exc:
        logger.exception("Failed to query Ollama")
        raise HTTPException(status_code=503, detail=f"Failed to query Ollama: {exc}")
class Query(BaseModel):
    """Request body for the /chat endpoint."""

    # Natural-language question about the indexed transactions.
    question: str
@app.get("/transactions/{year}")
async def list_transactions(year: int):
try:
q = text(
"SELECT date, payee, category, memo, amount "
"FROM transactions WHERE strftime('%Y', date)=:yr"
)
with indexer.engine.connect() as conn:
rows_db = conn.execute(q, {"yr": f"{year}"}).fetchall()
if not rows_db:
logger.info("No transactions found for year %s", year)
return {"transactions": []}
rows = []
for row in rows_db:
r = dict(row._mapping)
date_val = r.get("date")
r["date"] = date_val.isoformat() if hasattr(date_val, "isoformat") else (str(date_val) if date_val else None)
r["amount"] = f"${r['amount']:,.2f}" if r["amount"] is not None else None
rows.append(r)
logger.info("Returned %s transactions for year %s", len(rows), year)
return {"transactions": rows}
except Exception as e:
logger.exception("Transaction listing error")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
def health_check():
try:
response = requests.get(f"{ollama_url}/api/tags", timeout=5)
response.raise_for_status()
logger.info("Health check OK")
return {"status": "ok"}
except Exception as e:
logger.error("Health check failed: %s", e)
raise HTTPException(status_code=503, detail=str(e))
@app.get("/count")
async def count_transactions():
try:
with indexer.engine.connect() as conn:
row = conn.execute(text("SELECT COUNT(*) as cnt FROM transactions")).fetchone()
count = row[0] if row is not None else 0
logger.info("Total transactions count: %s", count)
return {"count": count}
except Exception as e:
logger.exception("Error counting transactions")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/transactions/count/{year}")
async def count_transactions_year(year: int):
try:
with indexer.engine.connect() as conn:
row = conn.execute(
text("SELECT COUNT(*) FROM transactions WHERE strftime('%Y', date)=:yr"),
{"yr": f"{year}"},
).fetchone()
count = row[0] if row is not None else 0
logger.info("Transactions count for %s: %s", year, count)
return {"year": year, "count": count}
except Exception as e:
logger.exception("Error counting transactions for year %s", year)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/chat")
async def chat(query: Query):
user_question = query.question.strip()
if not user_question:
raise HTTPException(status_code=400, detail="Question cannot be empty.")
sql = generate_sql(user_question)
try:
with indexer.engine.connect() as conn:
result = conn.execute(text(sql))
rows = []
for row in result:
d = dict(row._mapping)
date_val = d.get("date")
d["date"] = date_val.isoformat() if hasattr(date_val, "isoformat") else (str(date_val) if date_val else None)
amt = d.get("amount")
if isinstance(amt, (int, float)):
d["amount"] = f"${amt:,.2f}"
rows.append(d)
return {"answer": format_human_readable(rows)}
except Exception as e:
logger.exception("SQL execution error")
raise HTTPException(status_code=500, detail=f"SQL execution failed: {e}")