-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpatch.main.py
More file actions
353 lines (338 loc) · 13.7 KB
/
patch.main.py
File metadata and controls
353 lines (338 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
diff --git a/app/main.py b/app/main.py
index 8fb60576bd63ec3a493fee4181ec6320c40ee732..3201109e9f56819f7789b8d0fc813ba7bdf8ef5a 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,213 +1,211 @@
-# ---------------------------
-# app/main.py
-# ---------------------------
-import os
-import requests
+import json
import logging
+import os
import re
-import json
-from fastapi import FastAPI, HTTPException, Request
+
+import requests
+from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
-from langchain_community.utilities import SQLDatabase
-from langchain_community.llms import Ollama
from sqlalchemy import text
+
from app.qif_indexer import QIFIndexer
# Configure logging.
# NOTE(review): getattr(logging, default_level) raises AttributeError for an
# unknown LOG_LEVEL value (e.g. "TRACE") — assumes the env var is a valid
# stdlib level name; confirm deployment config.
default_level = os.getenv("LOG_LEVEL", "INFO")
logging.basicConfig(level=getattr(logging, default_level))
logger = logging.getLogger(__name__)
# Environment variables.
# qif_dir / db_path default to container-style absolute paths; ollama_url and
# ollama_model select the LLM backend used by generate_sql() and /health.
qif_dir = os.getenv("QIF_DIR", "/qifs")
db_path = os.getenv("DB_PATH", "/db/transactions.db")
ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
ollama_model = os.getenv("OLLAMA_MODEL", "phi4-mini:3.8b")
# Create FastAPI app.
app = FastAPI()
# Ensure database: index QIF files into the SQLite DB at import time, so the
# endpoints below can assume the transactions table exists.
indexer = QIFIndexer(qif_dir, db_path)
indexer.ensure_database()
logger.info("Database ready at %s", db_path)
def format_markdown_table(rows):
    """Render a list of mapping-like rows as a markdown table.

    Column order follows the keys of the first row; every cell is
    stringified with ``str``. Returns "No results found." for empty input.
    """
    if not rows:
        return "No results found."
    columns = list(rows[0].keys())
    lines = [
        "| " + " | ".join(columns) + " |",
        "| " + " | ".join("---" for _ in columns) + " |",
    ]
    for record in rows:
        cells = [str(record[col]) for col in columns]
        lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(lines)
def format_human_readable(rows):
    """Turn query rows into a short English answer.

    A single row with a single column (a SQL aggregate such as COUNT/SUM)
    becomes a sentence; anything else falls back to a markdown table.
    Returns "No results found." for empty input.
    """
    if not rows:
        return "No results found."
    first = rows[0]
    if len(rows) == 1 and len(first) == 1:
        ((key, value),) = first.items()
        return f"The {key.replace('_', ' ')} is {value}."
    return format_markdown_table(rows)
def sanitize_llm_sql(raw_sql: str) -> str:
    """Allow only a single SELECT statement and block dangerous SQL operations.

    Strips markdown code fences and a trailing semicolon, then validates that
    the result is exactly one SELECT containing no write/DDL keywords.

    Raises:
        HTTPException(500): empty output, multiple statements, non-SELECT,
            or a blocked keyword.
    """
    if not raw_sql:
        raise HTTPException(status_code=500, detail="No SQL was generated by the LLM.")

    # Strip code fences BEFORE the multiple-statement check. Previously the
    # ";" check ran on the raw text, where a trailing ``` fence prevented
    # rstrip(";") from removing the final semicolon, so a legitimate single
    # statement like "```sql\nSELECT ...;\n```" was rejected.
    sql = re.sub(r"```sql\s*", "", raw_sql, flags=re.IGNORECASE)
    sql = re.sub(r"```", "", sql).strip()

    # Any semicolon left after dropping trailing ones separates statements.
    if ";" in sql.rstrip(";"):
        raise HTTPException(status_code=500, detail="Multiple SQL statements are not allowed.")

    sql = sql.rstrip(";").strip()
    if not sql:
        raise HTTPException(status_code=500, detail="No SQL was generated by the LLM.")

    lowered = sql.lower()
    disallowed = ["insert", "update", "delete", "drop", "alter", "pragma", "attach", "vacuum"]
    if not lowered.startswith("select"):
        raise HTTPException(status_code=500, detail=f"Only SELECT is allowed. SQL: {sql}")
    # Word-boundary match so column names merely containing a keyword
    # (e.g. "last_update_note") are not matched as whole words like "update".
    if any(re.search(rf"\b{keyword}\b", lowered) for keyword in disallowed):
        raise HTTPException(status_code=500, detail="Generated SQL contains blocked keywords.")

    return sql
+
def generate_sql(question: str) -> str:
    """Ask the Ollama model to translate *question* into one SQLite SELECT.

    Streams POST {ollama_url}/api/generate; each response line is a JSON
    object whose "response" field holds a text fragment. Fragments are
    concatenated and passed through sanitize_llm_sql().

    Raises:
        HTTPException(500): Ollama returned a non-200 status, or the
            accumulated text failed sanitization.
        HTTPException(503): the HTTP request to Ollama itself failed.
    """
    schema = "transactions(date DATE, payee TEXT, category TEXT, memo TEXT, amount REAL)"
    prompt = (
        "You are a SQLite SQL expert. Only return a valid SQLite SELECT statement for the question below, "
        f"using this schema:\n{schema}\n"
        "Use table name 'transactions'. Use strftime('%Y', date) for year filtering. "
        "Never use markdown or code fences. Never return non-SQL text.\n"
        f"Question: {question}\nSQL:"
    )

    try:
        # timeout=(connect, read): fail fast on connect, allow slow generation.
        with requests.post(
            f"{ollama_url}/api/generate",
            json={"model": ollama_model, "prompt": prompt},
            stream=True,
            timeout=(5, 60),
        ) as response:
            if response.status_code != 200:
                raise HTTPException(status_code=500, detail=f"Ollama error: {response.text}")

            raw_sql = ""
            for line in response.iter_lines():
                if not line:
                    continue
                line_decoded = line.decode("utf-8")
                logger.debug("Raw line from LLM: %s", line_decoded)
                try:
                    obj = json.loads(line_decoded)
                    raw_sql += obj.get("response", "")
                except Exception as e:
                    # Tolerate malformed stream lines; keep the rest of the SQL.
                    logger.warning("Failed to parse JSON: %s | Line: %s", e, line_decoded)

            logger.info("Raw SQL from LLM before cleanup: %s", raw_sql)
            return sanitize_llm_sql(raw_sql)
    except requests.RequestException as exc:
        # Network-level failure only; HTTPException above propagates untouched.
        logger.exception("Failed to query Ollama")
        raise HTTPException(status_code=503, detail=f"Failed to query Ollama: {exc}")
class Query(BaseModel):
    """Request body for POST /chat: a free-form natural-language question."""

    question: str
@app.get("/transactions/{year}")
async def list_transactions(year: int):
    """Return all transactions whose date falls in *year*.

    Response shape: {"transactions": [{date, payee, category, memo, amount}]}
    with dates as ISO strings and amounts formatted as "$1,234.56".

    Raises:
        HTTPException(500): any database or formatting error.
    """
    try:
        # Parameterized query; year is bound as a string to match strftime output.
        q = text(
            "SELECT date, payee, category, memo, amount "
            "FROM transactions WHERE strftime('%Y', date)=:yr"
        )
        with indexer.engine.connect() as conn:
            rows_db = conn.execute(q, {"yr": f"{year}"}).fetchall()

        if not rows_db:
            logger.info("No transactions found for year %s", year)
            return {"transactions": []}

        rows = []
        for row in rows_db:
            r = dict(row._mapping)
            # date may be a date/datetime object or a plain string from SQLite.
            date_val = r.get("date")
            r["date"] = date_val.isoformat() if hasattr(date_val, "isoformat") else (str(date_val) if date_val else None)
            r["amount"] = f"${r['amount']:,.2f}" if r["amount"] is not None else None
            rows.append(r)

        logger.info("Returned %s transactions for year %s", len(rows), year)
        return {"transactions": rows}
    except Exception as e:
        logger.exception("Transaction listing error")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
def health_check():
    """Liveness probe: verify the Ollama server answers GET /api/tags.

    Returns {"status": "ok"} on success.

    Raises:
        HTTPException(503): Ollama is unreachable or returned an error status.
    """
    try:
        response = requests.get(f"{ollama_url}/api/tags", timeout=5)
        response.raise_for_status()
        logger.info("Health check OK")
        return {"status": "ok"}
    except Exception as e:
        logger.error("Health check failed: %s", e)
        raise HTTPException(status_code=503, detail=str(e))
@app.get("/count")
async def count_transactions():
    """Return the total number of rows in the transactions table.

    Response shape: {"count": <int>}.

    Raises:
        HTTPException(500): any database error.
    """
    try:
        with indexer.engine.connect() as conn:
            row = conn.execute(text("SELECT COUNT(*) as cnt FROM transactions")).fetchone()
        # fetchone() should always yield a row for COUNT(*); guard anyway.
        count = row[0] if row is not None else 0
        logger.info("Total transactions count: %s", count)
        return {"count": count}
    except Exception as e:
        logger.exception("Error counting transactions")
        raise HTTPException(status_code=500, detail=str(e))
-
- # New endpoint: count transactions for a given year
-@app.get('/transactions/count/{year}')
+
+
@app.get("/transactions/count/{year}")
async def count_transactions_year(year: int):
    """Return the number of transactions dated within *year*.

    Response shape: {"year": <int>, "count": <int>}.

    Raises:
        HTTPException(500): any database error.
    """
    try:
        with indexer.engine.connect() as conn:
            row = conn.execute(
                text("SELECT COUNT(*) FROM transactions WHERE strftime('%Y', date)=:yr"),
                {"yr": f"{year}"},
            ).fetchone()
        # fetchone() should always yield a row for COUNT(*); guard anyway.
        count = row[0] if row is not None else 0
        logger.info("Transactions count for %s: %s", year, count)
        return {"year": year, "count": count}
    except Exception as e:
        logger.exception("Error counting transactions for year %s", year)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/chat")
async def chat(query: Query):
    """Answer a natural-language question by generating and running SQL.

    Pipeline: validate the question → generate_sql() (LLM, already
    sanitized to a single SELECT) → execute against SQLite → format the
    rows with format_human_readable().

    Raises:
        HTTPException(400): empty question.
        HTTPException(500/503): propagated from generate_sql(), or SQL
            execution failure (500).
    """
    user_question = query.question.strip()
    if not user_question:
        raise HTTPException(status_code=400, detail="Question cannot be empty.")

    # Kept outside the try below so LLM errors keep their own status codes
    # instead of being re-wrapped as "SQL execution failed".
    sql = generate_sql(user_question)

    try:
        with indexer.engine.connect() as conn:
            result = conn.execute(text(sql))
            rows = []
            for row in result:
                d = dict(row._mapping)
                # date may be a date/datetime object or a plain string from SQLite.
                date_val = d.get("date")
                d["date"] = date_val.isoformat() if hasattr(date_val, "isoformat") else (str(date_val) if date_val else None)
                amt = d.get("amount")
                if isinstance(amt, (int, float)):
                    d["amount"] = f"${amt:,.2f}"
                rows.append(d)
        return {"answer": format_human_readable(rows)}
    except Exception as e:
        logger.exception("SQL execution error")
        raise HTTPException(status_code=500, detail=f"SQL execution failed: {e}")