import argparse
import importlib.util
import os
import subprocess
import sys
import time

import torch
import torch.distributed as dist
import yaml

from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle
from trinity.utils.dlc_utils import get_dlc_env_vars
1315
1416
1517def set_engine_num (config , args ):
1618 config ["cluster" ]["node_num" ] = args .node_num
1719 config ["cluster" ]["gpu_per_node" ] = args .gpu_per_node
18- batch_size = config ["buffer" ]["batch_size" ]
20+ batch_size = config ["buffer" ]["batch_size" ] * config [ "algorithm" ][ "repeat_times" ]
1921 if config ["mode" ] == "train" :
2022 return
2123
@@ -61,6 +63,84 @@ def update_opt_explorer_num(trainer_gpu_num, opt_explorer_num, opt_ratio_diff):
6163 config ["explorer" ]["rollout_model" ]["engine_num" ] = opt_explorer_num
6264
6365
def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
    """Ensure a usable taskset path for ``dataset_name``, generating data if needed.

    If ``taskset_path`` already exists on disk (or is the special hub id
    ``"openai/gsm8k"`` for the gsm8k dataset), it is returned unchanged.
    Otherwise the dataset is generated by running the matching script from the
    ``scripts/`` directory next to this file, e.g.::

        python scripts/gen_guru_data.py --local_dir <taskset_path>

    When ``taskset_path`` is ``None`` (or empty), the generator module's
    ``DEFAULT_DATA_PATH`` constant is used as the output location.

    Args:
        dataset_name: Name of the dataset. Only datasets with a registered
            generator script ("countdown", "guru") support on-the-fly
            generation; "gsm8k" is only accepted as the hub id.
        taskset_path: Requested dataset path, or ``None``/"" to fall back to
            the generator's default.

    Returns:
        str: The resolved (real) path to the dataset.

    Raises:
        ValueError: If ``dataset_name`` has no registered generator script.
        FileNotFoundError: If the generator script does not exist on disk.
        ImportError: If a module spec cannot be created for the script.
        AttributeError: If a default path is needed but the generator module
            does not define ``DEFAULT_DATA_PATH``.
        subprocess.CalledProcessError: If the generation script exits with a
            non-zero status (``check=True``).

    Side Effects:
        May create directories and files on disk by running the generation
        script in a subprocess.
    """
    if taskset_path:
        if os.path.exists(taskset_path):
            return taskset_path
        # gsm8k is loaded directly from the HuggingFace hub, not a local path.
        if dataset_name == "gsm8k" and taskset_path == "openai/gsm8k":
            return taskset_path

    dataset_script_map = {
        "countdown": "gen_countdown_data.py",
        "guru": "gen_guru_data.py",
    }
    if dataset_name not in dataset_script_map:
        raise ValueError(
            f"Unsupported dataset: {dataset_name}. Please specify a valid taskset path."
        )

    base_dir = os.path.dirname(__file__)
    script_filename = dataset_script_map[dataset_name]
    script_path = os.path.join(base_dir, "scripts", script_filename)
    if not os.path.exists(script_path):
        raise FileNotFoundError(f"Generator script not found: {script_path}")

    # Import the generator script as a module so DEFAULT_DATA_PATH can be read
    # without going through its CLI entry point.
    module_name = os.path.splitext(script_filename)[0]
    spec = importlib.util.spec_from_file_location(module_name, script_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load spec for module: {module_name}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Fall back to the generator's default output location. Testing falsiness
    # (rather than `is None`) also treats an empty string as "use the default",
    # preventing os.path.realpath("") from silently resolving to the CWD.
    if not taskset_path:
        if not hasattr(module, "DEFAULT_DATA_PATH"):
            raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
        taskset_path = module.DEFAULT_DATA_PATH
    taskset_path = os.path.realpath(taskset_path)

    # Reuse the already-resolved script path (the original recomputed the join).
    subprocess.run([sys.executable, script_path, "--local_dir", taskset_path], check=True)

    return taskset_path
142+
143+
64144def prepare_configs (args , rank , current_time ):
65145 base_path = os .path .dirname (os .path .abspath (__file__ ))
66146
@@ -89,18 +169,19 @@ def prepare_configs(args, rank, current_time):
89169 )
90170 if args .critic_lr :
91171 config ["trainer" ]["trainer_config" ]["critic" ]["optim" ]["lr" ] = args .critic_lr
92- config ["buffer" ]["explorer_input" ]["taskset" ][ "path" ] = (
93- args . taskset_path
94- or os . environ . get ( "TASKSET_PATH" )
95- or config [ "buffer" ][ "explorer_input" ][ "taskset" ][ " path" ]
172+ taskset_config = config ["buffer" ]["explorer_input" ]["taskset" ]
173+ taskset_config [ "path" ] = check_taskset_path (
174+ args . dataset ,
175+ args . taskset_path or os . environ . get ( "TASKSET_PATH" ) or taskset_config [ " path" ],
96176 )
97- assert (
98- config ["buffer" ]["explorer_input" ]["taskset" ]["path" ] is not None
99- ), "Please specify taskset path."
100177 if args .lr :
101178 config ["algorithm" ]["optimizer" ]["lr" ] = args .lr
102179 if args .sync_interval :
103180 config ["synchronizer" ]["sync_interval" ] = args .sync_interval
181+ if args .sync_offset :
182+ config ["synchronizer" ]["sync_offset" ] = args .sync_offset
183+ if args .sync_style :
184+ config ["synchronizer" ]["sync_style" ] = args .sync_style
104185
105186 with open (config_path , "w" ) as f :
106187 yaml .dump (config , f , allow_unicode = True , sort_keys = False )
@@ -131,7 +212,7 @@ def main(args):
131212 rank , current_time = 0 , time .time ()
132213 config_path = prepare_configs (args , rank , current_time )
133214 cmd_list = [
134- "python" ,
215+ sys . executable ,
135216 "-m" ,
136217 "trinity.cli.launcher" ,
137218 "run" ,
@@ -142,12 +223,16 @@ def main(args):
142223 dist .barrier ()
143224 dist .destroy_process_group ()
144225 cmd_list .append ("--dlc" )
226+ if args .dataset == "guru" :
227+ base_path = os .path .dirname (os .path .abspath (__file__ ))
228+ cmd_list .append ("--plugin-dir" )
229+ cmd_list .append (os .path .join (base_path , "plugins" ))
145230 subprocess .run (cmd_list , check = True )
146231
147232
148233if __name__ == "__main__" :
149234 parser = argparse .ArgumentParser ()
150- parser .add_argument ("dataset" , type = str , choices = ["gsm8k" , "countdown" , "openr1 " ])
235+ parser .add_argument ("dataset" , type = str . lower , choices = ["gsm8k" , "countdown" , "guru " ])
151236 parser .add_argument (
152237 "--dlc" , action = "store_true" , help = "Specify when running in Aliyun PAI DLC."
153238 )
@@ -191,5 +276,12 @@ def main(args):
191276 parser .add_argument (
192277 "--sync_interval" , type = int , default = None , help = "Specify the sync interval."
193278 )
279+ parser .add_argument ("--sync_offset" , type = int , default = None , help = "Specify the sync offset." )
280+ parser .add_argument (
281+ "--sync_style" ,
282+ type = str ,
283+ default = None ,
284+ choices = [sync_style .value for sync_style in SyncStyle ],
285+ )
194286 args = parser .parse_args ()
195287 main (args )
0 commit comments