1515import transformers
1616import yaml
1717
18+ from trinity .utils .dlc_utils import is_running , setup_ray_cluster , stop_ray_cluster
19+
1820# Default list of GPU counts to test
1921DEFAULT_GPU_NUMS : List [int ] = [1 , 2 , 4 , 6 ]
2022EXCEPTION_STRING = "Traceback (most recent call last)"
@@ -51,6 +53,9 @@ def monitor_output(
5153 if EXCEPTION_STRING in line :
5254 exception_event .set ()
5355
56+ if exception_event .is_set ():
57+ print (line , end = "" , flush = True )
58+
5459 # Check for oom
5560 if OOM_STRING in line :
5661 exception_event .set ()
@@ -64,6 +69,7 @@ def run_command_with_monitor(
6469 command : List [str ],
6570 envs : dict [str , str ],
6671 log_path : str ,
72+ checkpoint_path : str ,
6773 timeout : Optional [int ] = None ,
6874) -> bool :
6975 """Runs a shell command with real-time output monitoring and early termination support.
@@ -77,26 +83,27 @@ def run_command_with_monitor(
7783 command: Command to execute, as a list of strings.
7884 envs: Environment variables to set for the command.
7985 log_path: Path to the log file where output will be saved.
86+ checkpoint_path: Path to the checkpoint directory.
8087 timeout: Optional timeout in seconds before forcing termination.
8188
8289 Returns:
8390 True if the command completed successfully without OOM error; False otherwise.
8491 """
8592 retry_flag = True
8693 success_flag = False
87- checkpoint_root = os .environ .get ("TRINITY_CHECKPOINT_ROOT_DIR" , "./checkpoints/length-test" )
94+ envs ["TRINITY_CHECKPOINT_ROOT_DIR" ] = checkpoint_path
95+ process_env = os .environ .copy ()
96+ process_env .update (envs )
8897
8998 while retry_flag :
9099 # Clean up checkpoint directory before each run
91- shutil .rmtree (checkpoint_root , ignore_errors = True )
100+ shutil .rmtree (checkpoint_path , ignore_errors = True )
92101
93102 exception_event = threading .Event ()
94103 oom_event = threading .Event ()
95104
96105 with open (log_path , "w" , encoding = "utf-8" ) as log_file :
97106 # Start subprocess with merged stdout/stderr
98- process_env = os .environ .copy ()
99- process_env .update (envs )
100107 process = subprocess .Popen (
101108 command ,
102109 stdout = subprocess .PIPE ,
@@ -160,6 +167,7 @@ def run_command_with_monitor(
160167def find_max_model_len (
161168 model_path : str ,
162169 model_config ,
170+ checkpoint_path : str ,
163171 trainer_gpu_num : int ,
164172 sp_num : int ,
165173 base_log_dir : str ,
@@ -178,6 +186,7 @@ def find_max_model_len(
178186 Args:
179187 model_path: Path to the pretrained model.
180188 model_config: Loaded Hugging Face model configuration.
189+ checkpoint_path: Path to the checkpoint directory.
181190 trainer_gpu_num: Number of GPUs allocated.
182191 sp_num: Number of sequence parallel groups.
183192 base_log_dir: Base directory for saving logs.
@@ -253,6 +262,7 @@ def find_max_model_len(
253262 cmd_base ,
254263 cmd_env ,
255264 logfile ,
265+ checkpoint_path ,
256266 timeout = timeout ,
257267 )
258268
@@ -278,6 +288,13 @@ def find_max_model_len(
278288
279289def main (args ):
280290 """Main entry point: orchestrates multi-GPU, multi-SP context length testing."""
291+ if args .dlc :
292+ cluster_namespace = "search_context_length_capacity"
293+ setup_ray_cluster (namespace = cluster_namespace )
294+
295+ if not is_running ():
296+ raise RuntimeError ("Ray is not running, please start it by `ray start --head`." )
297+
281298 os .makedirs (args .log_dir , exist_ok = True )
282299
283300 model_name = os .path .basename (args .model_path )
@@ -300,6 +317,7 @@ def main(args):
300317 max_length = find_max_model_len (
301318 model_path = args .model_path ,
302319 model_config = model_config ,
320+ checkpoint_path = args .checkpoint_path ,
303321 trainer_gpu_num = trainer_gpu_num ,
304322 sp_num = sp_num ,
305323 base_log_dir = args .log_dir ,
@@ -319,6 +337,9 @@ def main(args):
319337 f"max_model_len = { max_length } "
320338 )
321339
340+ if args .dlc :
341+ stop_ray_cluster (namespace = cluster_namespace )
342+
322343
323344if __name__ == "__main__" :
324345 default_log_dir = os .path .join (os .path .dirname (__file__ ), "logs" )
@@ -343,6 +364,14 @@ def main(args):
343364 default = default_log_dir ,
344365 help = "Directory to store experiment logs." ,
345366 )
367+ parser .add_argument (
368+ "--checkpoint_path" ,
369+ type = str ,
370+ default = os .environ .get ("TRINITY_CHECKPOINT_ROOT_DIR" , "./checkpoints/length-test" ),
371+ help = "Checkpoint path for testing. "
372+ "Note that this directory will be deleted during the test, "
373+ "please specify a path that is not used by other processes." ,
374+ )
346375 parser .add_argument (
347376 "--test_gpu_num" ,
348377 type = int ,
@@ -387,6 +416,9 @@ def main(args):
387416 default = 2400 ,
388417 help = "Timeout for each experiment in seconds." ,
389418 )
419+ parser .add_argument (
420+ "--dlc" , action = "store_true" , help = "Specify when running in Aliyun PAI DLC."
421+ )
390422
391423 args = parser .parse_args ()
392424 main (args )
0 commit comments