|
12 | 12 | from db import JsonFileStore |
13 | 13 | from graph.nodes.chain_node import BaseChainNode |
14 | 14 | from apps.text_generator import ReportGenerator |
| 15 | +from apps.graph_analyzer import LLMGraphAnalyzer |
15 | 16 | from models import set_llm_model, DefaultEmbedding, DEFAULT_LLM_MODEL_CONFIG |
16 | 17 |
|
17 | 18 | from threading import Thread |
@@ -128,6 +129,7 @@ def __init__( |
128 | 129 | self.llm_report() |
129 | 130 |
|
130 | 131 | self.text_generator = ReportGenerator(self.llm.model) |
| 132 | + self.graph_analyzer = LLMGraphAnalyzer(self.llm.model) |
131 | 133 |
|
132 | 134 | def get_persist_store(self, dataset_id): |
133 | 135 | persist_store = self.persist_stores.setdefault( |
@@ -1228,6 +1230,111 @@ def do_summarizing(): |
1228 | 1230 | traceback.print_exc() |
1229 | 1231 | return create_error_response(str(e)), 500 |
1230 | 1232 |
|
| 1233 | + @self.app.route("/api/llm/analyze/fetch", methods=["POST"]) |
| 1234 | + def fetch_query(): |
| 1235 | + try: |
| 1236 | + # Extract the JSON payload |
| 1237 | + input_data = request.get_json() |
| 1238 | + required_fields = [ |
| 1239 | + "query", |
| 1240 | + "schema", |
| 1241 | + ] |
| 1242 | + for field in required_fields: |
| 1243 | + if field not in input_data: |
| 1244 | + return ( |
| 1245 | + create_error_response(f"Missing {field} in request"), |
| 1246 | + 400, |
| 1247 | + ) |
| 1248 | + |
| 1249 | + query = input_data["query"] |
| 1250 | + schema = input_data["schema"] |
| 1251 | + lang = input_data.get("lang", "cypher") |
| 1252 | + |
| 1253 | + output_prompt = self.graph_analyzer.get_fetch_query( |
| 1254 | + query=query, schema=schema, lang=lang |
| 1255 | + ) |
| 1256 | + |
| 1257 | + # Return success response |
| 1258 | + return ( |
| 1259 | + create_json_response({"prompts": output_prompt}), |
| 1260 | + 200, |
| 1261 | + ) |
| 1262 | + |
| 1263 | + except Exception as e: |
| 1264 | + traceback.print_exc() |
| 1265 | + return create_error_response(str(e)), 500 |
| 1266 | + |
| 1267 | + @self.app.route("/api/llm/analyze/mindmap", methods=["POST"]) |
| 1268 | + def get_mind_map(): |
| 1269 | + try: |
| 1270 | + # Extract the JSON payload |
| 1271 | + input_data = request.get_json() |
| 1272 | + required_fields = [ |
| 1273 | + "query", |
| 1274 | + "data", |
| 1275 | + ] |
| 1276 | + for field in required_fields: |
| 1277 | + if field not in input_data: |
| 1278 | + return ( |
| 1279 | + create_error_response(f"Missing {field} in request"), |
| 1280 | + 400, |
| 1281 | + ) |
| 1282 | + |
| 1283 | + query = input_data["query"] |
| 1284 | + data = input_data["data"] |
| 1285 | + |
| 1286 | + output_json = self.graph_analyzer.get_mind_map(query=query, data=data) |
| 1287 | + |
| 1288 | + # Return success response |
| 1289 | + return ( |
| 1290 | + create_json_response({"mind_map": output_json}), |
| 1291 | + 200, |
| 1292 | + ) |
| 1293 | + |
| 1294 | + except Exception as e: |
| 1295 | + traceback.print_exc() |
| 1296 | + return create_error_response(str(e)), 500 |
| 1297 | + |
| 1298 | + @self.app.route("/api/llm/analyze/writereport", methods=["POST"]) |
| 1299 | + def get_report(): |
| 1300 | + try: |
| 1301 | + # Extract the JSON payload |
| 1302 | + input_data = request.get_json() |
| 1303 | + required_fields = [ |
| 1304 | + "query", |
| 1305 | + "mind_map", |
| 1306 | + ] |
| 1307 | + for field in required_fields: |
| 1308 | + if field not in input_data: |
| 1309 | + return ( |
| 1310 | + create_error_response(f"Missing {field} in request"), |
| 1311 | + 400, |
| 1312 | + ) |
| 1313 | + |
| 1314 | + query = input_data["query"] |
| 1315 | + mind_map = input_data["mind_map"] |
| 1316 | + max_token_per_subsection = input_data.get( |
| 1317 | + "max_token_per_subsection", 100 |
| 1318 | + ) |
| 1319 | + bib2id = input_data.get("bib2id", {}) |
| 1320 | + |
| 1321 | + output_text = self.graph_analyzer.write_report_sec_by_sec( |
| 1322 | + query=query, |
| 1323 | + mind_map=mind_map, |
| 1324 | + max_token_per_subsection=max_token_per_subsection, |
| 1325 | + bib2id=bib2id, |
| 1326 | + ) |
| 1327 | + |
| 1328 | + # Return success response |
| 1329 | + return ( |
| 1330 | + create_json_response({"text": output_text}), |
| 1331 | + 200, |
| 1332 | + ) |
| 1333 | + |
| 1334 | + except Exception as e: |
| 1335 | + traceback.print_exc() |
| 1336 | + return create_error_response(str(e)), 500 |
| 1337 | + |
1231 | 1338 | @self.app.route("/api/llm/report/prepare", methods=["POST"]) |
1232 | 1339 | def prepare_report(): |
1233 | 1340 | try: |
|
0 commit comments