Skip to content

Commit 04e928c

Browse files
fix(webui): sync gradio demo
1 parent 3137c4b commit 04e928c

File tree

7 files changed

+115
-102
lines changed

7 files changed

+115
-102
lines changed

.github/workflows/push-to-hf.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ jobs:
4343
[[ -d hf-repo ]] && rm -rf hf-repo
4444
git clone https://huggingface.co/${HF_REPO_TYPE}/${HF_REPO_ID} hf-repo
4545
46-
rsync -a --delete --exclude='.git' ./ hf-repo/ || true
46+
rsync -a --delete --exclude='.git' --exclude='hf-repo' ./ hf-repo/
4747
4848
cd hf-repo
4949
git add .

graphgen/graphgen.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,8 @@
2323
judge_statement,
2424
quiz,
2525
search_all,
26-
traverse_graph_atomically,
27-
traverse_graph_by_edge,
26+
traverse_graph_for_aggregated,
27+
traverse_graph_for_atomic,
2828
traverse_graph_for_multi_hop,
2929
)
3030
from .utils import (
@@ -69,6 +69,7 @@ def __post_init__(self):
6969
self.tokenizer_instance: Tokenizer = Tokenizer(
7070
model_name=self.config["tokenizer"]
7171
)
72+
print(os.getenv("SYNTHESIZER_MODEL"), os.getenv("SYNTHESIZER_API_KEY"))
7273
self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
7374
model_name=os.getenv("SYNTHESIZER_MODEL"),
7475
api_key=os.getenv("SYNTHESIZER_API_KEY"),
@@ -326,7 +327,7 @@ async def async_traverse(self):
326327
output_data_type = self.config["output_data_type"]
327328

328329
if output_data_type == "atomic":
329-
results = await traverse_graph_atomically(
330+
results = await traverse_graph_for_atomic(
330331
self.synthesizer_llm_client,
331332
self.tokenizer_instance,
332333
self.graph_storage,
@@ -344,7 +345,7 @@ async def async_traverse(self):
344345
self.progress_bar,
345346
)
346347
elif output_data_type == "aggregated":
347-
results = await traverse_graph_by_edge(
348+
results = await traverse_graph_for_aggregated(
348349
self.synthesizer_llm_client,
349350
self.tokenizer_instance,
350351
self.graph_storage,

graphgen/operators/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@
55
from .judge import judge_statement
66
from .quiz import quiz
77
from .traverse_graph import (
8-
traverse_graph_atomically,
9-
traverse_graph_by_edge,
8+
traverse_graph_for_aggregated,
9+
traverse_graph_for_atomic,
1010
traverse_graph_for_multi_hop,
1111
)
1212

@@ -15,8 +15,8 @@
1515
"quiz",
1616
"judge_statement",
1717
"search_all",
18-
"traverse_graph_by_edge",
19-
"traverse_graph_atomically",
18+
"traverse_graph_for_aggregated",
19+
"traverse_graph_for_atomic",
2020
"traverse_graph_for_multi_hop",
2121
"generate_cot",
2222
]

graphgen/operators/traverse_graph.py

Lines changed: 25 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ def _post_process_synthetic_data(data):
158158
return qas
159159

160160

161-
async def traverse_graph_by_edge(
161+
async def traverse_graph_for_aggregated(
162162
llm_client: OpenAIModel,
163163
tokenizer: Tokenizer,
164164
graph_storage: NetworkXStorage,
@@ -251,7 +251,6 @@ async def _process_single_batch(
251251
qas = _post_process_synthetic_data(content)
252252

253253
if len(qas) == 0:
254-
print(content)
255254
logger.error(
256255
"Error occurred while processing batch, question or answer is None"
257256
)
@@ -307,7 +306,8 @@ async def _process_single_batch(
307306
return results
308307

309308

310-
async def traverse_graph_atomically(
309+
# pylint: disable=too-many-branches, too-many-statements
310+
async def traverse_graph_for_atomic(
311311
llm_client: OpenAIModel,
312312
tokenizer: Tokenizer,
313313
graph_storage: NetworkXStorage,
@@ -328,17 +328,28 @@ async def traverse_graph_atomically(
328328
:param max_concurrent
329329
:return: question and answer
330330
"""
331-
assert traverse_strategy.qa_form == "atomic"
332331

332+
assert traverse_strategy.qa_form == "atomic"
333333
semaphore = asyncio.Semaphore(max_concurrent)
334334

335+
def _parse_qa(qa: str) -> tuple:
336+
if "Question:" in qa and "Answer:" in qa:
337+
question = qa.split("Question:")[1].split("Answer:")[0].strip()
338+
answer = qa.split("Answer:")[1].strip()
339+
elif "问题:" in qa and "答案:" in qa:
340+
question = qa.split("问题:")[1].split("答案:")[0].strip()
341+
answer = qa.split("答案:")[1].strip()
342+
else:
343+
return None, None
344+
return question.strip('"'), answer.strip('"')
345+
335346
async def _generate_question(node_or_edge: tuple):
336347
if len(node_or_edge) == 2:
337348
des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
338-
loss = node_or_edge[1]["loss"]
349+
loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0
339350
else:
340351
des = node_or_edge[2]["description"]
341-
loss = node_or_edge[2]["loss"]
352+
loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0
342353

343354
async with semaphore:
344355
try:
@@ -350,13 +361,8 @@ async def _generate_question(node_or_edge: tuple):
350361
)
351362
)
352363

353-
if "Question:" in qa and "Answer:" in qa:
354-
question = qa.split("Question:")[1].split("Answer:")[0].strip()
355-
answer = qa.split("Answer:")[1].strip()
356-
elif "问题:" in qa and "答案:" in qa:
357-
question = qa.split("问题:")[1].split("答案:")[0].strip()
358-
answer = qa.split("答案:")[1].strip()
359-
else:
364+
question, answer = _parse_qa(qa)
365+
if question is None or answer is None:
360366
return {}
361367

362368
question = question.strip('"')
@@ -386,16 +392,18 @@ async def _generate_question(node_or_edge: tuple):
386392
if "<SEP>" in node[1]["description"]:
387393
description_list = node[1]["description"].split("<SEP>")
388394
for item in description_list:
389-
tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
395+
tasks.append((node[0], {"description": item}))
396+
if "loss" in node[1]:
397+
tasks[-1][1]["loss"] = node[1]["loss"]
390398
else:
391399
tasks.append((node[0], node[1]))
392400
for edge in edges:
393401
if "<SEP>" in edge[2]["description"]:
394402
description_list = edge[2]["description"].split("<SEP>")
395403
for item in description_list:
396-
tasks.append(
397-
(edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
398-
)
404+
tasks.append((edge[0], edge[1], {"description": item}))
405+
if "loss" in edge[2]:
406+
tasks[-1][2]["loss"] = edge[2]["loss"]
399407
else:
400408
tasks.append((edge[0], edge[1], edge[2]))
401409

0 commit comments

Comments
 (0)