|
1 | 1 | import asyncio |
| 2 | + |
| 3 | +from numba.scripts.generate_lower_listing import description |
2 | 4 | from tqdm.asyncio import tqdm as tqdm_async |
3 | 5 |
|
4 | 6 | from models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage |
@@ -296,3 +298,103 @@ async def _process_single_batch( |
296 | 298 | logger.error("Error occurred while processing batches: %s", e) |
297 | 299 |
|
298 | 300 | return results |
| 301 | + |
| 302 | + |
| 303 | +async def traverse_graph_atomically( |
| 304 | + llm_client: OpenAIModel, |
| 305 | + tokenizer: Tokenizer, |
| 306 | + graph_storage: NetworkXStorage, |
| 307 | + traverse_strategy: TraverseStrategy, |
| 308 | + text_chunks_storage: JsonKVStorage, |
| 309 | + max_concurrent: int = 1000 |
| 310 | +) -> dict: |
| 311 | + """ |
| 312 | + Traverse the graph atomicly |
| 313 | +
|
| 314 | + :param llm_client |
| 315 | + :param tokenizer |
| 316 | + :param graph_storage |
| 317 | + :param traverse_strategy |
| 318 | + :param text_chunks_storage |
| 319 | + :param max_concurrent |
| 320 | + :return: question and answer |
| 321 | + """ |
| 322 | + |
| 323 | + assert traverse_strategy.qa_form == "atomic" |
| 324 | + |
| 325 | + semaphore = asyncio.Semaphore(max_concurrent) |
| 326 | + |
| 327 | + async def _generate_question( |
| 328 | + node_or_edge: tuple |
| 329 | + ): |
| 330 | + if len(node_or_edge) == 2: |
| 331 | + des = node_or_edge[0] + ": " + node_or_edge[1]['description'] |
| 332 | + answer = node_or_edge[1]['description'] |
| 333 | + else: |
| 334 | + des = node_or_edge[2]['description'] |
| 335 | + answer = node_or_edge[2]['description'] |
| 336 | + |
| 337 | + async with semaphore: |
| 338 | + try: |
| 339 | + language = "Chinese" if detect_main_language(des) == "zh" else "English" |
| 340 | + question = await llm_client.generate_answer( |
| 341 | + QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format( |
| 342 | + answer=des |
| 343 | + ) |
| 344 | + ) |
| 345 | + if question.startswith("Question:"): |
| 346 | + question = question[len("Question:"):].strip() |
| 347 | + elif question.startswith("问题:"): |
| 348 | + question = question[len("问题:"):].strip() |
| 349 | + |
| 350 | + question = question.strip("\"") |
| 351 | + answer = answer.strip("\"") |
| 352 | + |
| 353 | + logger.info("Question: %s", question) |
| 354 | + logger.info("Answer: %s", answer) |
| 355 | + return { |
| 356 | + compute_content_hash(question): { |
| 357 | + "question": question, |
| 358 | + "answer": answer, |
| 359 | + "loss": -1, |
| 360 | + "difficulty": "medium" |
| 361 | + } |
| 362 | + } |
| 363 | + except Exception as e: # pylint: disable=broad-except |
| 364 | + logger.error("Error occurred while generating question: %s", e) |
| 365 | + return {} |
| 366 | + |
| 367 | + results = {} |
| 368 | + edges = list(await graph_storage.get_all_edges()) |
| 369 | + nodes = list(await graph_storage.get_all_nodes()) |
| 370 | + |
| 371 | + edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes) |
| 372 | + |
| 373 | + # TODO: 需要把node的name也加进去,或者只用edge,两种都试一下 |
| 374 | + tasks = [] |
| 375 | + # des中可能会有SEP分割符 |
| 376 | + for node in nodes: |
| 377 | + if "<SEP>" in node[1]['description']: |
| 378 | + description_list = node[1]['description'].split("<SEP>") |
| 379 | + for item in description_list: |
| 380 | + tasks.append((node[0], {"description": item})) |
| 381 | + else: |
| 382 | + tasks.append((node[0], node[1])) |
| 383 | + for edge in edges: |
| 384 | + if "<SEP>" in edge[2]['description']: |
| 385 | + description_list = edge[2]['description'].split("<SEP>") |
| 386 | + for item in description_list: |
| 387 | + tasks.append((edge[0], edge[1], {"description": item})) |
| 388 | + else: |
| 389 | + tasks.append((edge[0], edge[1], edge[2])) |
| 390 | + |
| 391 | + for result in tqdm_async( |
| 392 | + asyncio.as_completed([_generate_question(task) for task in tasks]), |
| 393 | + total=len(tasks), |
| 394 | + desc="Generating questions" |
| 395 | + ): |
| 396 | + try: |
| 397 | + results.update(await result) |
| 398 | + except Exception as e: # pylint: disable=broad-except |
| 399 | + logger.error("Error occurred while generating questions: %s", e) |
| 400 | + return results |
0 commit comments