8181}
8282
8383
84- def setup_env (tempdir ):
85- os . environ [ "TRAINING_SCRIPT" ] = SCRIPT
86- os . environ [ "PYTHONPATH" ] = "./:$PYTHONPATH"
87- os . environ [ "TERMINATION_LOG_FILE" ] = tempdir + "/ termination-log"
84+ def setup_env (monkeypatch , tempdir ):
85+ monkeypatch . setenv ( "TRAINING_SCRIPT" , SCRIPT )
86+ monkeypatch . setenv ( "PYTHONPATH" , "./:$PYTHONPATH" )
87+ monkeypatch . setenv ( "TERMINATION_LOG_FILE" , os . path . join ( tempdir , " termination-log"))
8888
8989
90- def cleanup_env ():
91- os .environ .pop ("TRAINING_SCRIPT" , None )
92- os .environ .pop ("PYTHONPATH" , None )
93- os .environ .pop ("TERMINATION_LOG_FILE" , None )
94-
95-
96- def test_successful_ft ():
90+ def test_successful_ft (monkeypatch ):
9791 """Check if we can bootstrap and fine tune causallm models"""
9892 with tempfile .TemporaryDirectory () as tempdir :
99- setup_env (tempdir )
93+ setup_env (monkeypatch , tempdir )
10094 TRAIN_KWARGS = {** BASE_KWARGS , ** {"output_dir" : tempdir }}
10195 serialized_args = serialize_args (TRAIN_KWARGS )
102- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
96+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
10397
10498 assert main () == 0
10599 _validate_termination_files_when_tuning_succeeds (tempdir )
@@ -108,43 +102,43 @@ def test_successful_ft():
108102
109103
110104@pytest .mark .skipif (True , reason = "This test is deprecated so always skipped" )
111- def test_successful_pt ():
105+ def test_successful_pt (monkeypatch ):
112106 """Check if we can bootstrap and peft tune causallm models"""
113107 with tempfile .TemporaryDirectory () as tempdir :
114- setup_env (tempdir )
108+ setup_env (monkeypatch , tempdir )
115109 TRAIN_KWARGS = {** BASE_PEFT_KWARGS , ** {"output_dir" : tempdir }}
116110 serialized_args = serialize_args (TRAIN_KWARGS )
117- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
111+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
118112
119113 assert main () == 0
120114 _validate_termination_files_when_tuning_succeeds (tempdir )
121115 checkpoint = os .path .join (tempdir , get_highest_checkpoint (tempdir ))
122116 _validate_training_output (checkpoint , "pt" )
123117
124118
125- def test_successful_lora ():
119+ def test_successful_lora (monkeypatch ):
126120 """Check if we can bootstrap and LoRA tune causallm models"""
127121 with tempfile .TemporaryDirectory () as tempdir :
128- setup_env (tempdir )
122+ setup_env (monkeypatch , tempdir )
129123 TRAIN_KWARGS = {** BASE_LORA_KWARGS , ** {"output_dir" : tempdir }}
130124 serialized_args = serialize_args (TRAIN_KWARGS )
131- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
125+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
132126
133127 assert main () == 0
134128 _validate_termination_files_when_tuning_succeeds (tempdir )
135129 checkpoint = os .path .join (tempdir , get_highest_checkpoint (tempdir ))
136130 _validate_training_output (checkpoint , "lora" )
137131
138132
139- def test_lora_save_model_dir_separate_dirs ():
133+ def test_lora_save_model_dir_separate_dirs (monkeypatch ):
140134 """Run LoRA tuning with separate save_model_dir and output_dir.
141135 Verify model saved to save_model_dir and checkpoints saved to
142136 output_dir.
143137 """
144138 with tempfile .TemporaryDirectory () as tempdir :
145139 output_dir = os .path .join (tempdir , "output_dir" )
146140 save_model_dir = os .path .join (tempdir , "save_model_dir" )
147- setup_env (tempdir )
141+ setup_env (monkeypatch , tempdir )
148142 TRAIN_KWARGS = {
149143 ** BASE_LORA_KWARGS ,
150144 ** {
@@ -154,7 +148,7 @@ def test_lora_save_model_dir_separate_dirs():
154148 },
155149 }
156150 serialized_args = serialize_args (TRAIN_KWARGS )
157- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
151+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
158152
159153 assert main () == 0
160154 _validate_termination_files_when_tuning_succeeds (output_dir )
@@ -165,12 +159,12 @@ def test_lora_save_model_dir_separate_dirs():
165159 assert len (checkpoints ) == 1
166160
167161
168- def test_lora_save_model_dir_same_dir_as_output_dir ():
162+ def test_lora_save_model_dir_same_dir_as_output_dir (monkeypatch ):
169163 """Run LoRA tuning with same save_model_dir and output_dir.
170164 Verify checkpoints, logs, and model saved to path.
171165 """
172166 with tempfile .TemporaryDirectory () as tempdir :
173- setup_env (tempdir )
167+ setup_env (monkeypatch , tempdir )
174168 TRAIN_KWARGS = {
175169 ** BASE_LORA_KWARGS ,
176170 ** {
@@ -180,7 +174,7 @@ def test_lora_save_model_dir_same_dir_as_output_dir():
180174 },
181175 }
182176 serialized_args = serialize_args (TRAIN_KWARGS )
183- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
177+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
184178
185179 assert main () == 0
186180 # check logs, checkpoint dir, and model exists in path
@@ -195,19 +189,21 @@ def test_lora_save_model_dir_same_dir_as_output_dir():
195189 assert len (checkpoints ) == TRAIN_KWARGS ["num_train_epochs" ]
196190
197191
198- def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no ():
192+ def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no (
193+ monkeypatch ,
194+ ):
199195 """Run LoRA tuning with same save_model_dir and output_dir and
200196 save_strategy=no. Verify no checkpoints created, only
201197 logs and final model.
202198 """
203199 with tempfile .TemporaryDirectory () as tempdir :
204- setup_env (tempdir )
200+ setup_env (monkeypatch , tempdir )
205201 TRAIN_KWARGS = {
206202 ** BASE_LORA_KWARGS ,
207203 ** {"output_dir" : tempdir , "save_model_dir" : tempdir , "save_strategy" : "no" },
208204 }
209205 serialized_args = serialize_args (TRAIN_KWARGS )
210- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
206+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
211207
212208 assert main () == 0
213209 # check that model and logs exists in output_dir
@@ -219,9 +215,9 @@ def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no():
219215 assert len (checkpoints ) == 0
220216
221217
222- def test_lora_with_lora_post_process_for_vllm_set_to_true ():
218+ def test_lora_with_lora_post_process_for_vllm_set_to_true (monkeypatch ):
223219 with tempfile .TemporaryDirectory () as tempdir :
224- setup_env (tempdir )
220+ setup_env (monkeypatch , tempdir )
225221 TRAIN_KWARGS = {
226222 ** BASE_LORA_KWARGS ,
227223 ** {
@@ -231,7 +227,7 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
231227 },
232228 }
233229 serialized_args = serialize_args (TRAIN_KWARGS )
234- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
230+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
235231
236232 assert main () == 0
237233 # check that model and logs exists in output_dir
@@ -255,9 +251,9 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
255251 not _is_package_available ("HFResourceScanner" ),
256252 reason = "Only runs if HFResourceScanner is installed" ,
257253)
258- def test_launch_with_HFResourceScanner_enabled ():
254+ def test_launch_with_HFResourceScanner_enabled (monkeypatch ):
259255 with tempfile .TemporaryDirectory () as tempdir :
260- setup_env (tempdir )
256+ setup_env (monkeypatch , tempdir )
261257 scanner_outfile = os .path .join (tempdir , TrackerConfigs .scanner_output_filename )
262258 TRAIN_KWARGS = {
263259 ** BASE_LORA_KWARGS ,
@@ -271,7 +267,7 @@ def test_launch_with_HFResourceScanner_enabled():
271267 },
272268 }
273269 serialized_args = serialize_args (TRAIN_KWARGS )
274- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
270+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
275271
276272 assert main () == 0
277273 assert os .path .exists (scanner_outfile ) is True
@@ -281,14 +277,14 @@ def test_launch_with_HFResourceScanner_enabled():
281277 assert scanner_res ["mem_data" ] is not None
282278
283279
284- def test_bad_script_path ():
280+ def test_bad_script_path (monkeypatch ):
285281 """Check for appropriate error for an invalid training script location"""
286282 with tempfile .TemporaryDirectory () as tempdir :
287- setup_env (tempdir )
283+ setup_env (monkeypatch , tempdir )
288284 TRAIN_KWARGS = {** BASE_LORA_KWARGS , ** {"output_dir" : tempdir }}
289285 serialized_args = serialize_args (TRAIN_KWARGS )
290- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
291- os . environ [ "TRAINING_SCRIPT" ] = "/not/here"
286+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
287+ monkeypatch . setenv ( "TRAINING_SCRIPT" , "/not/here" )
292288
293289 with pytest .raises (SystemExit ) as pytest_wrapped_e :
294290 main ()
@@ -297,61 +293,61 @@ def test_bad_script_path():
297293 assert os .stat (tempdir + "/termination-log" ).st_size > 0
298294
299295
300- def test_blank_env_var ():
296+ def test_blank_env_var (monkeypatch ):
301297 with tempfile .TemporaryDirectory () as tempdir :
302- setup_env (tempdir )
303- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = ""
298+ setup_env (monkeypatch , tempdir )
299+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , "" )
304300 with pytest .raises (SystemExit ) as pytest_wrapped_e :
305301 main ()
306302 assert pytest_wrapped_e .type == SystemExit
307303 assert pytest_wrapped_e .value .code == USER_ERROR_EXIT_CODE
308304 assert os .stat (tempdir + "/termination-log" ).st_size > 0
309305
310306
311- def test_faulty_file_path ():
307+ def test_faulty_file_path (monkeypatch ):
312308 with tempfile .TemporaryDirectory () as tempdir :
313- setup_env (tempdir )
309+ setup_env (monkeypatch , tempdir )
314310 faulty_path = os .path .join (tempdir , "non_existent_file.pkl" )
315311 TRAIN_KWARGS = {
316312 ** BASE_LORA_KWARGS ,
317313 ** {"training_data_path" : faulty_path , "output_dir" : tempdir },
318314 }
319315 serialized_args = serialize_args (TRAIN_KWARGS )
320- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
316+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
321317 with pytest .raises (SystemExit ) as pytest_wrapped_e :
322318 main ()
323319 assert pytest_wrapped_e .type == SystemExit
324320 assert pytest_wrapped_e .value .code == USER_ERROR_EXIT_CODE
325321 assert os .stat (tempdir + "/termination-log" ).st_size > 0
326322
327323
328- def test_bad_base_model_path ():
324+ def test_bad_base_model_path (monkeypatch ):
329325 with tempfile .TemporaryDirectory () as tempdir :
330- setup_env (tempdir )
326+ setup_env (monkeypatch , tempdir )
331327 TRAIN_KWARGS = {
332328 ** BASE_LORA_KWARGS ,
333329 ** {"model_name_or_path" : "/wrong/path" },
334330 "output_dir" : tempdir ,
335331 }
336332 serialized_args = serialize_args (TRAIN_KWARGS )
337- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
333+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
338334 with pytest .raises (SystemExit ) as pytest_wrapped_e :
339335 main ()
340336 assert pytest_wrapped_e .type == SystemExit
341337 assert pytest_wrapped_e .value .code == USER_ERROR_EXIT_CODE
342338 assert os .stat (tempdir + "/termination-log" ).st_size > 0
343339
344340
345- def test_config_parsing_error ():
341+ def test_config_parsing_error (monkeypatch ):
346342 with tempfile .TemporaryDirectory () as tempdir :
347- setup_env (tempdir )
343+ setup_env (monkeypatch , tempdir )
348344 TRAIN_KWARGS = {
349345 ** BASE_LORA_KWARGS ,
350346 ** {"num_train_epochs" : "five" },
351347 "output_dir" : tempdir ,
352348 } # Intentional type error
353349 serialized_args = serialize_args (TRAIN_KWARGS )
354- os . environ [ "SFT_TRAINER_CONFIG_JSON_ENV_VAR" ] = serialized_args
350+ monkeypatch . setenv ( "SFT_TRAINER_CONFIG_JSON_ENV_VAR" , serialized_args )
355351 with pytest .raises (SystemExit ) as pytest_wrapped_e :
356352 main ()
357353 assert pytest_wrapped_e .type == SystemExit
@@ -376,9 +372,3 @@ def _validate_training_output(base_dir, tuning_technique):
376372 else :
377373 assert os .path .exists (base_dir + "/adapter_config.json" ) is True
378374 assert os .path .exists (base_dir + "/adapter_model.safetensors" ) is True
379-
380-
381- def test_cleanup ():
382- # This runs to unset env variables that could disrupt other tests
383- cleanup_env ()
384- assert True
0 commit comments