77import shutil
88import time
99import unittest
10- from copy import deepcopy
10+ from copy import copy , deepcopy
1111from datetime import datetime
1212from unittest import mock
1313
@@ -331,12 +331,15 @@ def test_trainer(self, mock_load):
331331 ),
332332 ]
333333 self .config .check_and_update ()
334+ old_taskset_path = self .config .stages [1 ].buffer .explorer_input .taskset .path
335+ self .config .stages [1 ].buffer .explorer_input .taskset .path = "/invalid/path"
334336
335- mock_load .return_value = self .config
337+ mock_load .return_value = copy . deepcopy ( self .config )
336338
337- run (config_path = "dummy.yaml" )
339+ with self .assertRaises (Exception ):
340+ run (config_path = "dummy.yaml" )
338341
339- stage_configs = [cfg .check_and_update () for cfg in self .config ]
342+ stage_configs = [cfg .check_and_update () for cfg in copy . deepcopy ( self .config ) ]
340343
341344 # sft warmup stage
342345 sft_config = stage_configs [0 ]
@@ -351,6 +354,10 @@ def test_trainer(self, mock_load):
351354 self .assertEqual (parser .metric_min_step (response_metrics [0 ]), 1 )
352355 self .assertEqual (parser .metric_max_step (response_metrics [0 ]), 3 )
353356
357+ self .config .stages [1 ].buffer .explorer_input .taskset .path = old_taskset_path
358+ mock_load .return_value = copy .deepcopy (self .config )
359+ run (config_path = "dummy.yaml" )
360+
354361 # grpo stage
355362 grpo_config = stage_configs [1 ]
356363 parser = TensorBoardParser (os .path .join (grpo_config .monitor .cache_dir , "tensorboard" ))
0 commit comments