Skip to content

Commit 7ac6618

Browse files
Merge pull request #47 from open-sciencelab/sync-webui
fix(webui): sync gradio demo
2 parents 3137c4b + dbdb541 commit 7ac6618

File tree

8 files changed

+120

-105

lines changed

.github/workflows/push-to-hf.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
[[ -d hf-repo ]] && rm -rf hf-repo
4444
git clone https://huggingface.co/${HF_REPO_TYPE}/${HF_REPO_ID} hf-repo
4545
46-
rsync -a --delete --exclude='.git' ./ hf-repo/ || true
46+
rsync -a --delete --exclude='.git' --exclude='hf-repo' ./ hf-repo/
4747
4848
cd hf-repo
4949
git add .

graphgen/configs/multi_hop_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
99
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
10-
enabled: true
10+
enabled: false
1111
quiz_samples: 2 # number of quiz samples to generate
1212
re_judge: false # whether to re-judge the existing quiz samples
1313
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss

graphgen/graphgen.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
judge_statement,
2424
quiz,
2525
search_all,
26-
traverse_graph_atomically,
27-
traverse_graph_by_edge,
26+
traverse_graph_for_aggregated,
27+
traverse_graph_for_atomic,
2828
traverse_graph_for_multi_hop,
2929
)
3030
from .utils import (
@@ -69,6 +69,7 @@ def __post_init__(self):
6969
self.tokenizer_instance: Tokenizer = Tokenizer(
7070
model_name=self.config["tokenizer"]
7171
)
72+
print(os.getenv("SYNTHESIZER_MODEL"), os.getenv("SYNTHESIZER_API_KEY"))
7273
self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
7374
model_name=os.getenv("SYNTHESIZER_MODEL"),
7475
api_key=os.getenv("SYNTHESIZER_API_KEY"),
@@ -326,7 +327,7 @@ async def async_traverse(self):
326327
output_data_type = self.config["output_data_type"]
327328

328329
if output_data_type == "atomic":
329-
results = await traverse_graph_atomically(
330+
results = await traverse_graph_for_atomic(
330331
self.synthesizer_llm_client,
331332
self.tokenizer_instance,
332333
self.graph_storage,
@@ -344,7 +345,7 @@ async def async_traverse(self):
344345
self.progress_bar,
345346
)
346347
elif output_data_type == "aggregated":
347-
results = await traverse_graph_by_edge(
348+
results = await traverse_graph_for_aggregated(
348349
self.synthesizer_llm_client,
349350
self.tokenizer_instance,
350351
self.graph_storage,

graphgen/operators/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from .judge import judge_statement
66
from .quiz import quiz
77
from .traverse_graph import (
8-
traverse_graph_atomically,
9-
traverse_graph_by_edge,
8+
traverse_graph_for_aggregated,
9+
traverse_graph_for_atomic,
1010
traverse_graph_for_multi_hop,
1111
)
1212

@@ -15,8 +15,8 @@
1515
"quiz",
1616
"judge_statement",
1717
"search_all",
18-
"traverse_graph_by_edge",
19-
"traverse_graph_atomically",
18+
"traverse_graph_for_aggregated",
19+
"traverse_graph_for_atomic",
2020
"traverse_graph_for_multi_hop",
2121
"generate_cot",
2222
]

graphgen/operators/traverse_graph.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def get_average_loss(batch: tuple, loss_strategy: str) -> float:
135135
) / (len(batch[0]) + len(batch[1]))
136136
raise ValueError("Invalid loss strategy")
137137
except Exception as e: # pylint: disable=broad-except
138-
logger.error("Error calculating average loss: %s", e)
138+
logger.warning(
139+
"Loss not found in some nodes or edges, setting loss to -1.0: %s", e
140+
)
139141
return -1.0
140142

141143

@@ -158,7 +160,7 @@ def _post_process_synthetic_data(data):
158160
return qas
159161

160162

161-
async def traverse_graph_by_edge(
163+
async def traverse_graph_for_aggregated(
162164
llm_client: OpenAIModel,
163165
tokenizer: Tokenizer,
164166
graph_storage: NetworkXStorage,
@@ -251,7 +253,6 @@ async def _process_single_batch(
251253
qas = _post_process_synthetic_data(content)
252254

253255
if len(qas) == 0:
254-
print(content)
255256
logger.error(
256257
"Error occurred while processing batch, question or answer is None"
257258
)
@@ -307,7 +308,8 @@ async def _process_single_batch(
307308
return results
308309

309310

310-
async def traverse_graph_atomically(
311+
# pylint: disable=too-many-branches, too-many-statements
312+
async def traverse_graph_for_atomic(
311313
llm_client: OpenAIModel,
312314
tokenizer: Tokenizer,
313315
graph_storage: NetworkXStorage,
@@ -328,17 +330,28 @@ async def traverse_graph_atomically(
328330
:param max_concurrent
329331
:return: question and answer
330332
"""
331-
assert traverse_strategy.qa_form == "atomic"
332333

334+
assert traverse_strategy.qa_form == "atomic"
333335
semaphore = asyncio.Semaphore(max_concurrent)
334336

337+
def _parse_qa(qa: str) -> tuple:
338+
if "Question:" in qa and "Answer:" in qa:
339+
question = qa.split("Question:")[1].split("Answer:")[0].strip()
340+
answer = qa.split("Answer:")[1].strip()
341+
elif "问题:" in qa and "答案:" in qa:
342+
question = qa.split("问题:")[1].split("答案:")[0].strip()
343+
answer = qa.split("答案:")[1].strip()
344+
else:
345+
return None, None
346+
return question.strip('"'), answer.strip('"')
347+
335348
async def _generate_question(node_or_edge: tuple):
336349
if len(node_or_edge) == 2:
337350
des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
338-
loss = node_or_edge[1]["loss"]
351+
loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0
339352
else:
340353
des = node_or_edge[2]["description"]
341-
loss = node_or_edge[2]["loss"]
354+
loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0
342355

343356
async with semaphore:
344357
try:
@@ -350,13 +363,8 @@ async def _generate_question(node_or_edge: tuple):
350363
)
351364
)
352365

353-
if "Question:" in qa and "Answer:" in qa:
354-
question = qa.split("Question:")[1].split("Answer:")[0].strip()
355-
answer = qa.split("Answer:")[1].strip()
356-
elif "问题:" in qa and "答案:" in qa:
357-
question = qa.split("问题:")[1].split("答案:")[0].strip()
358-
answer = qa.split("答案:")[1].strip()
359-
else:
366+
question, answer = _parse_qa(qa)
367+
if question is None or answer is None:
360368
return {}
361369

362370
question = question.strip('"')
@@ -386,16 +394,18 @@ async def _generate_question(node_or_edge: tuple):
386394
if "<SEP>" in node[1]["description"]:
387395
description_list = node[1]["description"].split("<SEP>")
388396
for item in description_list:
389-
tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
397+
tasks.append((node[0], {"description": item}))
398+
if "loss" in node[1]:
399+
tasks[-1][1]["loss"] = node[1]["loss"]
390400
else:
391401
tasks.append((node[0], node[1]))
392402
for edge in edges:
393403
if "<SEP>" in edge[2]["description"]:
394404
description_list = edge[2]["description"].split("<SEP>")
395405
for item in description_list:
396-
tasks.append(
397-
(edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
398-
)
406+
tasks.append((edge[0], edge[1], {"description": item}))
407+
if "loss" in edge[2]:
408+
tasks[-1][2]["loss"] = edge[2]["loss"]
399409
else:
400410
tasks.append((edge[0], edge[1], edge[2]))
401411

0 commit comments

Comments (0)