Skip to content

Commit 9f31b50

Browse files
author
louyk18
committed
write by sec
1 parent eef8484 commit 9f31b50

File tree

6 files changed

+416
-81
lines changed

6 files changed

+416
-81
lines changed

python/graphy/apps/demo_app.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from db import JsonFileStore
1313
from graph.nodes.chain_node import BaseChainNode
1414
from apps.text_generator import ReportGenerator
15+
from apps.graph_analyzer import LLMGraphAnalyzer
1516
from models import set_llm_model, DefaultEmbedding, DEFAULT_LLM_MODEL_CONFIG
1617

1718
from threading import Thread
@@ -128,6 +129,7 @@ def __init__(
128129
self.llm_report()
129130

130131
self.text_generator = ReportGenerator(self.llm.model)
132+
self.graph_analyzer = LLMGraphAnalyzer(self.llm.model)
131133

132134
def get_persist_store(self, dataset_id):
133135
persist_store = self.persist_stores.setdefault(
@@ -1228,6 +1230,111 @@ def do_summarizing():
12281230
traceback.print_exc()
12291231
return create_error_response(str(e)), 500
12301232

1233+
@self.app.route("/api/llm/analyze/fetch", methods=["POST"])
def fetch_query():
    """Build a fetch-query prompt from a user query and a graph schema.

    Expects a JSON body with required keys "query" and "schema", plus an
    optional "lang" naming the target query language (default "cypher").

    Returns:
        (JSON {"prompts": ...}, 200) on success,
        (error JSON, 400) for a missing/invalid body or missing field,
        (error JSON, 500) on unexpected failures.
    """
    try:
        # silent=True makes a missing or malformed JSON body yield None
        # instead of raising, so it can be reported as a client error
        # (400) rather than surfacing as a 500 below.
        input_data = request.get_json(silent=True)
        if input_data is None:
            return create_error_response("Request body must be JSON"), 400

        required_fields = [
            "query",
            "schema",
        ]
        for field in required_fields:
            if field not in input_data:
                return (
                    create_error_response(f"Missing {field} in request"),
                    400,
                )

        query = input_data["query"]
        schema = input_data["schema"]
        lang = input_data.get("lang", "cypher")

        output_prompt = self.graph_analyzer.get_fetch_query(
            query=query, schema=schema, lang=lang
        )

        # Return success response
        return (
            create_json_response({"prompts": output_prompt}),
            200,
        )

    except Exception as e:
        traceback.print_exc()
        return create_error_response(str(e)), 500
1266+
1267+
@self.app.route("/api/llm/analyze/mindmap", methods=["POST"])
def get_mind_map():
    """Produce a mind map for a query over the supplied data.

    Expects a JSON body with required keys "query" and "data".

    Returns:
        (JSON {"mind_map": ...}, 200) on success,
        (error JSON, 400) for a missing/invalid body or missing field,
        (error JSON, 500) on unexpected failures.
    """
    try:
        # silent=True: a missing/malformed JSON body becomes None here,
        # which is reported as a 400 instead of raising into the
        # generic 500 handler below.
        input_data = request.get_json(silent=True)
        if input_data is None:
            return create_error_response("Request body must be JSON"), 400

        required_fields = [
            "query",
            "data",
        ]
        for field in required_fields:
            if field not in input_data:
                return (
                    create_error_response(f"Missing {field} in request"),
                    400,
                )

        query = input_data["query"]
        data = input_data["data"]

        output_json = self.graph_analyzer.get_mind_map(query=query, data=data)

        # Return success response
        return (
            create_json_response({"mind_map": output_json}),
            200,
        )

    except Exception as e:
        traceback.print_exc()
        return create_error_response(str(e)), 500
1297+
1298+
@self.app.route("/api/llm/analyze/writereport", methods=["POST"])
def get_report():
    """Write a section-by-section report from a query and its mind map.

    Expects a JSON body with required keys "query" and "mind_map", plus
    optional "max_token_per_subsection" (default 100) and "bib2id"
    (default empty mapping) forwarded to the analyzer.

    Returns:
        (JSON {"text": ...}, 200) on success,
        (error JSON, 400) for a missing/invalid body or missing field,
        (error JSON, 500) on unexpected failures.
    """
    try:
        # silent=True: treat an absent/unparseable JSON body as a client
        # error (400) rather than letting the membership test below
        # raise and turn into a 500.
        input_data = request.get_json(silent=True)
        if input_data is None:
            return create_error_response("Request body must be JSON"), 400

        required_fields = [
            "query",
            "mind_map",
        ]
        for field in required_fields:
            if field not in input_data:
                return (
                    create_error_response(f"Missing {field} in request"),
                    400,
                )

        query = input_data["query"]
        mind_map = input_data["mind_map"]
        max_token_per_subsection = input_data.get(
            "max_token_per_subsection", 100
        )
        bib2id = input_data.get("bib2id", {})

        output_text = self.graph_analyzer.write_report_sec_by_sec(
            query=query,
            mind_map=mind_map,
            max_token_per_subsection=max_token_per_subsection,
            bib2id=bib2id,
        )

        # Return success response
        return (
            create_json_response({"text": output_text}),
            200,
        )

    except Exception as e:
        traceback.print_exc()
        return create_error_response(str(e)), 500
1337+
12311338
@self.app.route("/api/llm/report/prepare", methods=["POST"])
12321339
def prepare_report():
12331340
try:

python/graphy/apps/graph_analyzer.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
TEMPLATE_QUERY_GENERATOR,
1818
TEMPLATE_MIND_MAP_GENERATOR,
1919
TEMPLATE_RELATED_WORK_GENERATOR,
20+
TEMPLATE_RELATED_WORK_INTRO_GENERATOR,
21+
TEMPLATE_RELATED_WORK_TEXT_PROMPT,
22+
TEMPLATE_TEXT_EXAMPLE_PROMPT,
23+
TEMPLATE_SUBSECTION_INSTRUCTION_PROMPT,
24+
TEMPLATE_PREVIOUS_SUBSECITON_PROMPT,
2025
)
2126

2227
logger = logging.getLogger()
@@ -93,6 +98,78 @@ def write_report(self, query, mind_map, max_token_per_subsection):
9398

9499
return result
95100

101+
def generate_section_texts(self, query, mind_map, max_token_per_subsection):
    """Build one LLM prompt per category in the mind map.

    Args:
        query: The user's original query string, embedded in each prompt.
        mind_map: Mapping whose "data" entry is a list of category dicts.
        max_token_per_subsection: Token budget hint embedded in each prompt.

    Returns:
        A list of formatted prompt strings, one per category, in order.
    """
    section_prompts = []

    # Subsection numbering in the prompt templates is 1-based.
    for subsection_id, category in enumerate(mind_map.get("data", []), start=1):
        prop_slot = str(category)

        generated_instruction = ""
        # Only the first subsection carries the worked example.
        if subsection_id == 1:
            generated_instruction += TEMPLATE_TEXT_EXAMPLE_PROMPT

        generated_instruction += TEMPLATE_SUBSECTION_INSTRUCTION_PROMPT.format(
            subsection_id=str(subsection_id),
            max_token_per_subsection=str(max_token_per_subsection),
        )

        # Later subsections get a <PREVIOUS></PREVIOUS> placeholder that
        # is filled with earlier output when the report is assembled
        # (see write_report_sec_by_sec).
        if subsection_id > 1:
            generated_instruction += TEMPLATE_PREVIOUS_SUBSECITON_PROMPT

        section_prompts.append(
            TEMPLATE_RELATED_WORK_TEXT_PROMPT.format(
                user_query=query,
                prop_slot=prop_slot,
                generate_instruction=generated_instruction,
            )
        )

    return section_prompts
130+
131+
def write_report_sec_by_sec(
    self, query, mind_map, max_token_per_subsection, bib2id=None
):
    """Generate a report: intro first, then one subsection at a time.

    Each subsection prompt may contain a <PREVIOUS></PREVIOUS>
    placeholder, which is replaced with all text generated so far so the
    model keeps continuity between subsections.

    Args:
        query: The user's original query string.
        mind_map: Mapping whose "data" entry lists category dicts with
            "name" and "description" keys.
        max_token_per_subsection: Token budget hint per subsection.
        bib2id: Optional bibliography mapping appended at the end;
            defaults to an empty mapping.

    Returns:
        The full report text (intro + sections + bibliography).
    """
    # Avoid a mutable default argument: materialize the empty mapping
    # per call instead of sharing one dict across invocations.
    if bib2id is None:
        bib2id = {}

    # Comma-joined JSON fragments describing each category feed the
    # intro-generation template.
    prop_slot = ",".join(
        json.dumps(
            {
                "name": category.get("name", ""),
                "description": category.get("description", ""),
            }
        )
        for category in mind_map.get("data", [])
    )

    intro_prompt = TEMPLATE_RELATED_WORK_INTRO_GENERATOR.format(prop_slot=prop_slot)
    intro_text = self.generate("get_report_intro", intro_prompt)

    section_prompts = self.generate_section_texts(
        query=query,
        mind_map=mind_map,
        max_token_per_subsection=max_token_per_subsection,
    )

    section_text = ""
    for text_prompt in section_prompts:
        # Inject everything generated so far so later subsections can
        # build on (and avoid repeating) earlier ones.
        if "<PREVIOUS></PREVIOUS>" in text_prompt:
            text_prompt = text_prompt.replace("<PREVIOUS></PREVIOUS>", section_text)
        new_section_text = self.generate("query_report_text", text_prompt)
        section_text += "\n" + new_section_text + "\n"

    final_section = intro_text + "\n" + section_text
    # NOTE(review): append_bib_text declares its parameter as id2bib but a
    # bib2id mapping is passed here — confirm the expected key direction.
    bib_text = self.append_bib_text(final_section, bib2id)

    final_section += bib_text

    return final_section
172+
96173
def append_bib_text(self, text, id2bib):
97174
bib_text = ""
98175
cited_papers = set()

0 commit comments

Comments
 (0)