11import argparse
22import os
33import time
4- from importlib .resources import files
4+ from importlib import resources
5+ from typing import Any , Dict
56
7+ import ray
68import yaml
79from dotenv import load_dotenv
10+ from ray .data .block import Block
11+ from ray .data .datasource .filename_provider import FilenameProvider
812
9- from graphgen .engine import Context , Engine , collect_ops
10- from graphgen .graphgen import GraphGen
13+ from graphgen .engine import Engine
14+ from graphgen .operators import operators
1115from graphgen .utils import logger , set_logger
1216
1317sys_path = os .path .abspath (os .path .dirname (__file__ ))
@@ -28,12 +32,38 @@ def save_config(config_path, global_config):
2832 )
2933
3034
35+ class NodeFilenameProvider (FilenameProvider ):
36+ def __init__ (self , node_id : str ):
37+ self .node_id = node_id
38+
39+ def get_filename_for_block (
40+ self , block : Block , write_uuid : str , task_index : int , block_index : int
41+ ) -> str :
42+ # format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.json
43+ return f"{ self .node_id } _{ write_uuid } _{ task_index :06d} _{ block_index :06d} .jsonl"
44+
45+ def get_filename_for_row (
46+ self ,
47+ row : Dict [str , Any ],
48+ write_uuid : str ,
49+ task_index : int ,
50+ block_index : int ,
51+ row_index : int ,
52+ ) -> str :
53+ raise NotImplementedError (
54+ f"Row-based filenames are not supported by write_json. "
55+ f"Node: { self .node_id } , write_uuid: { write_uuid } "
56+ )
57+
58+
3159def main ():
3260 parser = argparse .ArgumentParser ()
3361 parser .add_argument (
3462 "--config_file" ,
3563 help = "Config parameters for GraphGen." ,
36- default = files ("graphgen" ).joinpath ("configs" , "aggregated_config.yaml" ),
64+ default = resources .files ("graphgen" )
65+ .joinpath ("configs" )
66+ .joinpath ("aggregated_config.yaml" ),
3767 type = str ,
3868 )
3969 parser .add_argument (
@@ -51,6 +81,8 @@ def main():
5181 with open (args .config_file , "r" , encoding = "utf-8" ) as f :
5282 config = yaml .load (f , Loader = yaml .FullLoader )
5383
84+ engine = Engine (config , operators )
85+
5486 unique_id = int (time .time ())
5587
5688 output_path = os .path .join (working_dir , "data" , "graphgen" , f"{ unique_id } " )
@@ -65,15 +97,22 @@ def main():
6597 unique_id ,
6698 os .path .join (working_dir , f"{ unique_id } .log" ),
6799 )
68-
69- graph_gen = GraphGen (unique_id = unique_id , working_dir = working_dir )
70-
71- # share context between different steps
72- ctx = Context (config = config , graph_gen = graph_gen )
73- ops = collect_ops (config , graph_gen )
74-
75- # run operations
76- Engine (max_workers = config .get ("max_workers" , 4 )).run (ops , ctx )
100+ ds = ray .data .from_items ([])
101+ results = engine .execute (ds )
102+
103+ for node_id , dataset in results .items ():
104+ output_path = os .path .join (output_path , f"{ node_id } " )
105+ os .makedirs (output_path , exist_ok = True )
106+ dataset .write_json (
107+ output_path ,
108+ filename_provider = NodeFilenameProvider (node_id ),
109+ pandas_json_args_fn = lambda : {
110+ "force_ascii" : False ,
111+ "orient" : "records" ,
112+ "lines" : True ,
113+ },
114+ )
115+ logger .info ("Node %s results saved to %s" , node_id , output_path )
77116
78117 save_config (os .path .join (output_path , "config.yaml" ), config )
79118 logger .info ("GraphGen completed successfully. Data saved to %s" , output_path )
0 commit comments