Skip to content

Commit d460a2a

Browse files
feat: write results in output folder
1 parent 244deb4 commit d460a2a

File tree

2 files changed

+54
-18
lines changed

2 files changed

+54
-18
lines changed

graphgen/engine.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
198198
deps_set.update(n.dependencies)
199199
return all_ids - deps_set
200200

201-
def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, List[Any]]:
201+
def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
202202
sorted_nodes = self._topo_sort(self.config.nodes)
203203

204204
for node in sorted_nodes:
@@ -210,7 +210,4 @@ def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, List[Any]]:
210210
def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
211211
return ds.take_all()
212212

213-
results = ray.get(
214-
[_fetch_result.remote(self.datasets[node_id]) for node_id in leaf_nodes]
215-
)
216-
return dict(zip(leaf_nodes, results))
213+
return {node_id: self.datasets[node_id] for node_id in leaf_nodes}

graphgen/run.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import argparse
22
import os
33
import time
4-
from importlib.resources import files
4+
from importlib import resources
5+
from typing import Any, Dict
56

7+
import ray
68
import yaml
79
from dotenv import load_dotenv
10+
from ray.data.block import Block
11+
from ray.data.datasource.filename_provider import FilenameProvider
812

9-
from graphgen.engine import Context, Engine, collect_ops
10-
from graphgen.graphgen import GraphGen
13+
from graphgen.engine import Engine
14+
from graphgen.operators import operators
1115
from graphgen.utils import logger, set_logger
1216

1317
sys_path = os.path.abspath(os.path.dirname(__file__))
@@ -28,12 +32,38 @@ def save_config(config_path, global_config):
2832
)
2933

3034

35+
class NodeFilenameProvider(FilenameProvider):
36+
def __init__(self, node_id: str):
37+
self.node_id = node_id
38+
39+
def get_filename_for_block(
40+
self, block: Block, write_uuid: str, task_index: int, block_index: int
41+
) -> str:
42+
# format: {node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl
43+
return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"
44+
45+
def get_filename_for_row(
46+
self,
47+
row: Dict[str, Any],
48+
write_uuid: str,
49+
task_index: int,
50+
block_index: int,
51+
row_index: int,
52+
) -> str:
53+
raise NotImplementedError(
54+
f"Row-based filenames are not supported by write_json. "
55+
f"Node: {self.node_id}, write_uuid: {write_uuid}"
56+
)
57+
58+
3159
def main():
3260
parser = argparse.ArgumentParser()
3361
parser.add_argument(
3462
"--config_file",
3563
help="Config parameters for GraphGen.",
36-
default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
64+
default=resources.files("graphgen")
65+
.joinpath("configs")
66+
.joinpath("aggregated_config.yaml"),
3767
type=str,
3868
)
3969
parser.add_argument(
@@ -51,6 +81,8 @@ def main():
5181
with open(args.config_file, "r", encoding="utf-8") as f:
5282
config = yaml.load(f, Loader=yaml.FullLoader)
5383

84+
engine = Engine(config, operators)
85+
5486
unique_id = int(time.time())
5587

5688
output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
@@ -65,15 +97,22 @@ def main():
6597
unique_id,
6698
os.path.join(working_dir, f"{unique_id}.log"),
6799
)
68-
69-
graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
70-
71-
# share context between different steps
72-
ctx = Context(config=config, graph_gen=graph_gen)
73-
ops = collect_ops(config, graph_gen)
74-
75-
# run operations
76-
Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
100+
ds = ray.data.from_items([])
101+
results = engine.execute(ds)
102+
103+
for node_id, dataset in results.items():
104+
output_path = os.path.join(output_path, f"{node_id}")
105+
os.makedirs(output_path, exist_ok=True)
106+
dataset.write_json(
107+
output_path,
108+
filename_provider=NodeFilenameProvider(node_id),
109+
pandas_json_args_fn=lambda: {
110+
"force_ascii": False,
111+
"orient": "records",
112+
"lines": True,
113+
},
114+
)
115+
logger.info("Node %s results saved to %s", node_id, output_path)
77116

78117
save_config(os.path.join(output_path, "config.yaml"), config)
79118
logger.info("GraphGen completed successfully. Data saved to %s", output_path)

0 commit comments

Comments
 (0)