Skip to content

Commit 60f4d1b

Browse files
feat: use output config instead of relying on leaf node type to save output (#139)

* feat: use output config instead of relying on leaf node type to save output
* test: update e2e tests
1 parent 1c66e7e commit 60f4d1b

File tree

14 files changed: +49 −56 lines changed

examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ nodes:
3030
execution_params:
3131
replicas: 1
3232
batch_size: 128
33+
save_output: true
3334
params:
3435
method: schema_guided
3536
schema_path: graphgen/templates/extraction/schemas/legal_contract.json

examples/generate/generate_aggregated_qa/aggregated_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ nodes:
7474
execution_params:
7575
replicas: 1
7676
batch_size: 128
77+
save_output: true # save output
7778
params:
7879
method: aggregated # atomic, aggregated, multi_hop, cot, vqa
7980
data_format: ChatML # Alpaca, Sharegpt, ChatML

examples/generate/generate_atomic_qa/atomic_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ nodes:
5050
execution_params:
5151
replicas: 1
5252
batch_size: 128
53+
save_output: true
5354
params:
5455
method: atomic
5556
data_format: Alpaca

examples/generate/generate_cot_qa/cot_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ nodes:
5252
execution_params:
5353
replicas: 1
5454
batch_size: 128
55+
save_output: true
5556
params:
5657
method: cot
5758
data_format: Sharegpt

examples/generate/generate_multi_hop_qa/multi_hop_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ nodes:
5353
execution_params:
5454
replicas: 1
5555
batch_size: 128
56+
save_output: true
5657
params:
5758
method: multi_hop
5859
data_format: ChatML

examples/generate/generate_vqa/vqa_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ nodes:
5454
execution_params:
5555
replicas: 1
5656
batch_size: 128
57+
save_output: true
5758
params:
5859
method: vqa
5960
data_format: ChatML

graphgen/bases/datatypes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ class Node(BaseModel):
6363
default_factory=list, description="list of dependent node ids"
6464
)
6565
execution_params: dict = Field(
66-
default_factory=dict, description="execution parameters like replicas, batch_size"
66+
default_factory=dict,
67+
description="execution parameters like replicas, batch_size",
68+
)
69+
save_output: bool = Field(
70+
default=False, description="whether to save the output of this node"
6771
)
6872

6973
@classmethod

graphgen/engine.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1-
import os
21
import inspect
32
import logging
3+
import os
44
from collections import defaultdict, deque
55
from functools import wraps
66
from typing import Any, Callable, Dict, List, Set
7-
from dotenv import load_dotenv
87

98
import ray
109
import ray.data
10+
from dotenv import load_dotenv
1111
from ray.data import DataContext
1212

1313
from graphgen.bases import Config, Node
14-
from graphgen.utils import logger
1514
from graphgen.common import init_llm, init_storage
15+
from graphgen.utils import logger
1616

1717
load_dotenv()
1818

19+
1920
class Engine:
2021
def __init__(
2122
self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
@@ -42,7 +43,7 @@ def __init__(
4243
existing_env_vars = ray_init_kwargs["runtime_env"].get("env_vars", {})
4344
ray_init_kwargs["runtime_env"]["env_vars"] = {
4445
**all_env_vars,
45-
**existing_env_vars
46+
**existing_env_vars,
4647
}
4748

4849
if not ray.is_initialized():
@@ -265,24 +266,11 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
265266
f"Unsupported node type {node.type} for node {node.id}"
266267
)
267268

268-
@staticmethod
269-
def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
270-
all_ids = {n.id for n in nodes}
271-
deps_set = set()
272-
for n in nodes:
273-
deps_set.update(n.dependencies)
274-
return all_ids - deps_set
275-
276269
def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
277270
sorted_nodes = self._topo_sort(self.config.nodes)
278271

279272
for node in sorted_nodes:
280273
self._execute_node(node, initial_ds)
281274

282-
leaf_nodes = self._find_leaf_nodes(sorted_nodes)
283-
284-
@ray.remote
285-
def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
286-
return ds.take_all()
287-
288-
return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
275+
output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)]
276+
return {node.id: self.datasets[node.id] for node in output_nodes}

tests/e2e_tests/conftest.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,59 +5,48 @@
55

66

77
def run_generate_test(tmp_path: Path, config_name: str):
8-
"""
9-
Run the generate test with the given configuration file and temporary path.
10-
11-
Args:
12-
tmp_path: pytest temporary path
13-
config_name: configuration file name (e.g. "atomic_config.yaml")
14-
15-
Returns:
16-
tuple: (run_folder, json_files[0])
17-
"""
188
repo_root = Path(__file__).resolve().parents[2]
199
os.chdir(repo_root)
2010

21-
config_path = repo_root / "graphgen" / "configs" / config_name
22-
output_dir = tmp_path / "output"
23-
output_dir.mkdir(parents=True, exist_ok=True)
11+
config_path = repo_root / config_name
2412

2513
result = subprocess.run(
2614
[
2715
"python",
2816
"-m",
29-
"graphgen.generate",
17+
"graphgen.run",
3018
"--config_file",
3119
str(config_path),
32-
"--output_dir",
33-
str(output_dir),
3420
],
3521
capture_output=True,
3622
text=True,
3723
check=False,
3824
)
3925
assert result.returncode == 0, f"Script failed with error: {result.stderr}"
4026

41-
data_root = output_dir / "data" / "graphgen"
42-
assert data_root.exists(), f"{data_root} does not exist"
43-
run_folders = sorted(data_root.iterdir(), key=lambda p: p.name, reverse=True)
44-
assert run_folders, f"No run folders found in {data_root}"
27+
run_root = repo_root / "cache" / "output"
28+
assert run_root.exists(), f"{run_root} does not exist"
29+
run_folders = sorted(
30+
[p for p in run_root.iterdir() if p.is_dir()], key=lambda p: p.name, reverse=True
31+
)
32+
assert run_folders, f"No run folders found in {run_root}"
4533
run_folder = run_folders[0]
4634

47-
config_saved = run_folder / "config.yaml"
48-
assert config_saved.exists(), f"{config_saved} not found"
35+
node_dirs = [p for p in run_folder.iterdir() if p.is_dir()]
36+
assert node_dirs, f"No node outputs found in {run_folder}"
4937

50-
json_files = list(run_folder.glob("*.json"))
51-
assert json_files, f"No JSON output found in {run_folder}"
38+
json_files = []
39+
for nd in node_dirs:
40+
json_files.extend(nd.glob("*.jsonl"))
41+
assert json_files, f"No JSONL output found under nodes in {run_folder}"
5242

53-
log_files = list(run_folder.glob("*.log"))
54-
assert log_files, "No log file generated"
43+
log_file = repo_root / "cache" / "logs" / "Driver.log"
44+
assert log_file.exists(), "No log file generated"
5545

5646
with open(json_files[0], "r", encoding="utf-8") as f:
57-
data = json.load(f)
58-
assert (
59-
isinstance(data, list) and len(data) > 0
60-
), "JSON output is empty or not a list"
47+
first_line = f.readline().strip()
48+
assert first_line, "JSONL output is empty"
49+
data = json.loads(first_line)
50+
assert isinstance(data, dict), "First JSONL record is not a dict"
6151

6252
return run_folder, json_files[0]
63-

tests/e2e_tests/test_generate_aggregated.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44

55

66
def test_generate_aggregated(tmp_path: Path):
7-
run_generate_test(tmp_path, "aggregated_config.yaml")
7+
run_generate_test(
8+
tmp_path, "examples/generate/generate_aggregated_qa/aggregated_config.yaml"
9+
)

0 commit comments

Comments
 (0)