Commit f9d6dc3

fix: fix output node
1 parent 06fc6e3 commit f9d6dc3

4 files changed: +47 / -19 lines

graphgen/engine.py

Lines changed: 2 additions & 0 deletions
@@ -271,6 +271,8 @@ def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:

         for node in sorted_nodes:
             self._execute_node(node, initial_ds)
+            if getattr(node, "save_output", False):
+                self.datasets[node.id] = self.datasets[node.id].materialize()

         output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
         return {node.id: self.datasets[node.id] for node in output_nodes}
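
For context on the engine.py change: ray.data datasets are lazy, so materializing each output node's dataset right after it runs forces execution and pins the result before the final write step. A minimal, self-contained sketch of that behavior, with made-up row data:

    import ray

    # Illustrative rows only; the real node outputs come from the GraphGen pipeline.
    ds = ray.data.from_items([{"question": "q1", "answer": "a1"}])
    ds = ds.map(lambda row: {**row, "answer_len": len(row["answer"])})  # still lazy
    ds = ds.materialize()  # executes the plan now and keeps the blocks in memory
    print(ds.take(1))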

graphgen/models/evaluator/qa/mtld_evaluator.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@ class MTLDEvaluator(BaseEvaluator):

     def __init__(self, threshold: float = 0.72):
         self.nltk_helper = NLTKHelper()
-        self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("english"))
-        self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("chinese"))
+        self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("en"))
+        self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh"))
         self.threshold = threshold

     def evaluate(self, pair: QAPair) -> float:
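
The mtld_evaluator.py fix suggests NLTKHelper.get_stopwords is keyed on short language codes rather than full NLTK corpus names. A hypothetical sketch of that kind of lookup using NLTK directly (the real helper may differ; requires nltk.download("stopwords") beforehand):

    # Hypothetical mapping implied by the "english"/"chinese" -> "en"/"zh" change.
    from nltk.corpus import stopwords

    _LANG_NAMES = {"en": "english", "zh": "chinese"}

    def get_stopwords(lang_code: str) -> list[str]:
        # Translate the short code to the NLTK corpus name before loading.
        return stopwords.words(_LANG_NAMES[lang_code])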

graphgen/operators/evaluate/evaluate_service.py

Lines changed: 38 additions & 13 deletions
@@ -1,4 +1,5 @@
 from typing import Any
+
 import pandas as pd

 from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair
@@ -18,21 +19,32 @@ def __init__(self, working_dir: str = "cache", metrics: list[str] = None, **kwar
         self.metrics = metrics
         self.kwargs = kwargs
         self.evaluators = {}
+        self._init_evaluators()

     def _init_evaluators(self):
         for metric in self.metrics:
             if metric == "qa_length":
                 from graphgen.models import LengthEvaluator
+
                 self.evaluators[metric] = LengthEvaluator()
             elif metric == "qa_mtld":
                 from graphgen.models import MTLDEvaluator
-                self.evaluators[metric] = MTLDEvaluator(self.kwargs.get("mtld_params", {}))
+
+                self.evaluators[metric] = MTLDEvaluator(
+                    **self.kwargs.get("mtld_params", {})
+                )
             elif metric == "qa_reward_score":
                 from graphgen.models import RewardEvaluator
-                self.evaluators[metric] = RewardEvaluator(self.kwargs.get("reward_params", {}))
+
+                self.evaluators[metric] = RewardEvaluator(
+                    **self.kwargs.get("reward_params", {})
+                )
             elif metric == "qa_uni_score":
                 from graphgen.models import UniEvaluator
-                self.evaluators[metric] = UniEvaluator(self.kwargs.get("uni_params", {}))
+
+                self.evaluators[metric] = UniEvaluator(
+                    **self.kwargs.get("uni_params", {})
+                )
             else:
                 raise ValueError(f"Unknown metric: {metric}")

@@ -44,16 +56,13 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]:
         try:
             qa_pair = QAPair(
                 question=str(item.get("question", "")),
-                answer=str(item.get("answer", ""))
+                answer=str(item.get("answer", "")),
             )
             if not qa_pair.question or not qa_pair.answer:
                 self.logger.error("Empty question or answer, skipping.")
                 return {}
         except Exception as e:
-            self.logger.error(
-                "Error in QAPair creation: %s",
-                str(e)
-            )
+            self.logger.error("Error in QAPair creation: %s", str(e))
             return {}

         for metric, evaluator in self.evaluators.items():
@@ -65,17 +74,33 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]:
             else:
                 item[metric] = float(score)
         except Exception as e:
-            self.logger.error(
-                "Error in %s evaluation: %s",
-                metric,
-                str(e)
-            )
+            self.logger.error("Error in %s evaluation: %s", metric, str(e))
             item[metric] = None
+        return item
+
+    @staticmethod
+    def transform_messages_format(items: list[dict]) -> list[dict]:
+        """
+        Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...]
+        """
+        transformed = []
+        for item in items:
+            messages = item.get("messages", [])
+            question = next(
+                (m["content"] for m in messages if m.get("role") == "user"), ""
+            )
+            answer = next(
+                (m["content"] for m in messages if m.get("role") == "assistant"), ""
+            )
+
+            transformed.append({"question": question, "answer": answer})
+        return transformed

     def evaluate(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
         if not items:
             return []

+        items = self.transform_messages_format(items)
         results = run_concurrent(
             self._process_single,
             items,
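
To illustrate the new message-format conversion, here is a standalone replica of the added static method plus a quick shape check; the sample contents are made up:

    # Standalone copy of transform_messages_format for illustration only.
    def transform_messages_format(items: list[dict]) -> list[dict]:
        transformed = []
        for item in items:
            messages = item.get("messages", [])
            question = next((m["content"] for m in messages if m.get("role") == "user"), "")
            answer = next((m["content"] for m in messages if m.get("role") == "assistant"), "")
            transformed.append({"question": question, "answer": answer})
        return transformed

    items = [
        {
            "messages": [
                {"role": "user", "content": "What does GraphGen do?"},
                {"role": "assistant", "content": "It generates synthetic QA pairs."},
            ]
        }
    ]
    print(transform_messages_format(items))
    # [{'question': 'What does GraphGen do?', 'answer': 'It generates synthetic QA pairs.'}]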

graphgen/run.py

Lines changed: 5 additions & 4 deletions
@@ -91,18 +91,19 @@ def main():
     results = engine.execute(ds)

     for node_id, dataset in results.items():
-        output_path = os.path.join(output_path, f"{node_id}")
-        os.makedirs(output_path, exist_ok=True)
+        logger.info("Saving results for node %s", node_id)
+        node_output_path = os.path.join(output_path, f"{node_id}")
+        os.makedirs(node_output_path, exist_ok=True)
         dataset.write_json(
-            output_path,
+            node_output_path,
             filename_provider=NodeFilenameProvider(node_id),
             pandas_json_args_fn=lambda: {
                 "force_ascii": False,
                 "orient": "records",
                 "lines": True,
             },
         )
-        logger.info("Node %s results saved to %s", node_id, output_path)
+        logger.info("Node %s results saved to %s", node_id, node_output_path)

     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
