|
| 1 | +import math |
| 2 | + |
| 3 | +import gradio as gr |
| 4 | + |
| 5 | +from graphgen.bases import BaseLLMWrapper |
| 6 | +from graphgen.models import JsonKVStorage, NetworkXStorage |
| 7 | +from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT |
| 8 | +from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy |
| 9 | + |
| 10 | + |
| 11 | +import math |
| 12 | +from collections.abc import Iterable |
| 13 | + |
| 14 | +import pandas as pd |
| 15 | + |
| 16 | +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper |
| 17 | +from graphgen.common import init_llm, init_storage |
| 18 | +from graphgen.models import NetworkXStorage, JsonKVStorage |
| 19 | +from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT |
| 20 | +from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy |
| 21 | + |
| 22 | + |
class JudgeService:
    """Service for judging graph edges and nodes using a trainee LLM.

    The judging logic is not implemented in this class yet: calling the
    instance only reports a completion status, and ``judge`` produces no
    results.
    """

    def __init__(self, working_dir: str = "cache"):
        """
        Create the service and its trainee LLM client.

        :param working_dir: working directory for the service; it is only
            stored on the instance, nothing in this class reads it yet
        """
        # Keep the argument instead of silently discarding it.
        self.working_dir = working_dir
        self.llm_client: BaseLLMWrapper = init_llm("trainee")

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        """
        Return a one-row status frame for the given batch.

        :param batch: incoming batch; it is not inspected
        :return: DataFrame with a single ``status`` column whose value is
            ``"judging_completed"``
        """
        return pd.DataFrame([{"status": "judging_completed"}])

    def judge(self) -> Iterable[pd.DataFrame]:
        """
        Judge the statements in the graph storage.

        Placeholder: no judging is performed yet.

        :return: an empty iterable, so callers can loop over the result
        """
        # A bare ``return`` (None) would contradict the annotation and make
        # ``for frame in service.judge()`` raise TypeError.
        return ()
| 39 | + |
| 40 | + |
| 41 | + |
| 42 | +# async def judge_statement( # pylint: disable=too-many-statements |
| 43 | +# trainee_llm_client: BaseLLMWrapper, |
| 44 | +# graph_storage: NetworkXStorage, |
| 45 | +# rephrase_storage: JsonKVStorage, |
| 46 | +# re_judge: bool = False, |
| 47 | +# progress_bar: gr.Progress = None, |
| 48 | +# ) -> NetworkXStorage: |
| 49 | +# """ |
| 50 | +# Get all edges and nodes and judge them |
| 51 | +# |
| 52 | +# :param trainee_llm_client: judge the statements to get comprehension loss |
| 53 | +# :param graph_storage: graph storage instance |
| 54 | +# :param rephrase_storage: rephrase storage instance |
| 55 | +# :param re_judge: re-judge the relations |
| 56 | +# :param progress_bar: progress bar passed to run_concurrent to report judging progress |
| 57 | +# :return: |
| 58 | +# """ |
| 59 | +# |
| 60 | +# async def _judge_single_relation( |
| 61 | +# edge: tuple, |
| 62 | +# ): |
| 63 | +# source_id = edge[0] |
| 64 | +# target_id = edge[1] |
| 65 | +# edge_data = edge[2] |
| 66 | +# |
| 67 | +# if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None: |
| 68 | +# logger.debug( |
| 69 | +# "Edge %s -> %s already judged, loss: %s, skip", |
| 70 | +# source_id, |
| 71 | +# target_id, |
| 72 | +# edge_data["loss"], |
| 73 | +# ) |
| 74 | +# return source_id, target_id, edge_data |
| 75 | +# |
| 76 | +# description = edge_data["description"] |
| 77 | +# |
| 78 | +# try: |
| 79 | +# descriptions = rephrase_storage.get_by_id(description) |
| 80 | +# assert descriptions is not None |
| 81 | +# |
| 82 | +# judgements = [] |
| 83 | +# gts = [gt for _, gt in descriptions] |
| 84 | +# for description, gt in descriptions: |
| 85 | +# judgement = await trainee_llm_client.generate_topk_per_token( |
| 86 | +# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) |
| 87 | +# ) |
| 88 | +# judgements.append(judgement[0].top_candidates) |
| 89 | +# |
| 90 | +# loss = yes_no_loss_entropy(judgements, gts) |
| 91 | +# |
| 92 | +# logger.debug( |
| 93 | +# "Edge %s -> %s description: %s loss: %s", |
| 94 | +# source_id, |
| 95 | +# target_id, |
| 96 | +# description, |
| 97 | +# loss, |
| 98 | +# ) |
| 99 | +# |
| 100 | +# edge_data["loss"] = loss |
| 101 | +# except Exception as e: # pylint: disable=broad-except |
| 102 | +# logger.error( |
| 103 | +# "Error in judging relation %s -> %s: %s", source_id, target_id, e |
| 104 | +# ) |
| 105 | +# logger.info("Use default loss 0.1") |
| 106 | +# edge_data["loss"] = -math.log(0.1) |
| 107 | +# |
| 108 | +# graph_storage.update_edge(source_id, target_id, edge_data) |
| 109 | +# return source_id, target_id, edge_data |
| 110 | +# |
| 111 | +# edges = graph_storage.get_all_edges() |
| 112 | +# |
| 113 | +# await run_concurrent( |
| 114 | +# _judge_single_relation, |
| 115 | +# edges, |
| 116 | +# desc="Judging relations", |
| 117 | +# unit="relation", |
| 118 | +# progress_bar=progress_bar, |
| 119 | +# ) |
| 120 | +# |
| 121 | +# async def _judge_single_entity( |
| 122 | +# node: tuple, |
| 123 | +# ): |
| 124 | +# node_id = node[0] |
| 125 | +# node_data = node[1] |
| 126 | +# |
| 127 | +# if (not re_judge) and "loss" in node_data and node_data["loss"] is not None: |
| 128 | +# logger.debug( |
| 129 | +# "Node %s already judged, loss: %s, skip", node_id, node_data["loss"] |
| 130 | +# ) |
| 131 | +# return node_id, node_data |
| 132 | +# |
| 133 | +# description = node_data["description"] |
| 134 | +# |
| 135 | +# try: |
| 136 | +# descriptions = rephrase_storage.get_by_id(description) |
| 137 | +# assert descriptions is not None |
| 138 | +# |
| 139 | +# judgements = [] |
| 140 | +# gts = [gt for _, gt in descriptions] |
| 141 | +# for description, gt in descriptions: |
| 142 | +# judgement = await trainee_llm_client.generate_topk_per_token( |
| 143 | +# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) |
| 144 | +# ) |
| 145 | +# judgements.append(judgement[0].top_candidates) |
| 146 | +# |
| 147 | +# loss = yes_no_loss_entropy(judgements, gts) |
| 148 | +# |
| 149 | +# logger.debug("Node %s description: %s loss: %s", node_id, description, loss) |
| 150 | +# |
| 151 | +# node_data["loss"] = loss |
| 152 | +# except Exception as e: # pylint: disable=broad-except |
| 153 | +# logger.error("Error in judging entity %s: %s", node_id, e) |
| 154 | +# logger.error("Use default loss 0.1") |
| 155 | +# node_data["loss"] = -math.log(0.1) |
| 156 | +# |
| 157 | +# graph_storage.update_node(node_id, node_data) |
| 158 | +# return node_id, node_data |
| 159 | +# |
| 160 | +# nodes = graph_storage.get_all_nodes() |
| 161 | +# |
| 162 | +# await run_concurrent( |
| 163 | +# _judge_single_entity, |
| 164 | +# nodes, |
| 165 | +# desc="Judging entities", |
| 166 | +# unit="entity", |
| 167 | +# progress_bar=progress_bar, |
| 168 | +# ) |
| 169 | +# |
| 170 | +# return graph_storage |
0 commit comments