Skip to content

Commit ec5c79c

Browse files
committed
Added monkey patching
Signed-off-by: romit <[email protected]>
1 parent a86c868 commit ec5c79c

File tree

4 files changed

65 additions and 65 deletions (+65 / -65 lines changed)

4 files changed

65 additions and 65 deletions (+65 / -65 lines changed)

tests/build/test_launch_script.py

Lines changed: 46 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -81,25 +81,19 @@
8181
}
8282

8383

84-
def setup_env(tempdir):
85-
os.environ["TRAINING_SCRIPT"] = SCRIPT
86-
os.environ["PYTHONPATH"] = "./:$PYTHONPATH"
87-
os.environ["TERMINATION_LOG_FILE"] = tempdir + "/termination-log"
84+
def setup_env(monkeypatch, tempdir):
85+
monkeypatch.setenv("TRAINING_SCRIPT", SCRIPT)
86+
monkeypatch.setenv("PYTHONPATH", "./:$PYTHONPATH")
87+
monkeypatch.setenv("TERMINATION_LOG_FILE", os.path.join(tempdir, "termination-log"))
8888

8989

90-
def cleanup_env():
91-
os.environ.pop("TRAINING_SCRIPT", None)
92-
os.environ.pop("PYTHONPATH", None)
93-
os.environ.pop("TERMINATION_LOG_FILE", None)
94-
95-
96-
def test_successful_ft():
90+
def test_successful_ft(monkeypatch):
9791
"""Check if we can bootstrap and fine tune causallm models"""
9892
with tempfile.TemporaryDirectory() as tempdir:
99-
setup_env(tempdir)
93+
setup_env(monkeypatch, tempdir)
10094
TRAIN_KWARGS = {**BASE_KWARGS, **{"output_dir": tempdir}}
10195
serialized_args = serialize_args(TRAIN_KWARGS)
102-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
96+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
10397

10498
assert main() == 0
10599
_validate_termination_files_when_tuning_succeeds(tempdir)
@@ -108,43 +102,43 @@ def test_successful_ft():
108102

109103

110104
@pytest.mark.skipif(True, reason="This test is deprecated so always skipped")
111-
def test_successful_pt():
105+
def test_successful_pt(monkeypatch):
112106
"""Check if we can bootstrap and peft tune causallm models"""
113107
with tempfile.TemporaryDirectory() as tempdir:
114-
setup_env(tempdir)
108+
setup_env(monkeypatch, tempdir)
115109
TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}}
116110
serialized_args = serialize_args(TRAIN_KWARGS)
117-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
111+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
118112

119113
assert main() == 0
120114
_validate_termination_files_when_tuning_succeeds(tempdir)
121115
checkpoint = os.path.join(tempdir, get_highest_checkpoint(tempdir))
122116
_validate_training_output(checkpoint, "pt")
123117

124118

125-
def test_successful_lora():
119+
def test_successful_lora(monkeypatch):
126120
"""Check if we can bootstrap and LoRA tune causallm models"""
127121
with tempfile.TemporaryDirectory() as tempdir:
128-
setup_env(tempdir)
122+
setup_env(monkeypatch, tempdir)
129123
TRAIN_KWARGS = {**BASE_LORA_KWARGS, **{"output_dir": tempdir}}
130124
serialized_args = serialize_args(TRAIN_KWARGS)
131-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
125+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
132126

133127
assert main() == 0
134128
_validate_termination_files_when_tuning_succeeds(tempdir)
135129
checkpoint = os.path.join(tempdir, get_highest_checkpoint(tempdir))
136130
_validate_training_output(checkpoint, "lora")
137131

138132

139-
def test_lora_save_model_dir_separate_dirs():
133+
def test_lora_save_model_dir_separate_dirs(monkeypatch):
140134
"""Run LoRA tuning with separate save_model_dir and output_dir.
141135
Verify model saved to save_model_dir and checkpoints saved to
142136
output_dir.
143137
"""
144138
with tempfile.TemporaryDirectory() as tempdir:
145139
output_dir = os.path.join(tempdir, "output_dir")
146140
save_model_dir = os.path.join(tempdir, "save_model_dir")
147-
setup_env(tempdir)
141+
setup_env(monkeypatch, tempdir)
148142
TRAIN_KWARGS = {
149143
**BASE_LORA_KWARGS,
150144
**{
@@ -154,7 +148,7 @@ def test_lora_save_model_dir_separate_dirs():
154148
},
155149
}
156150
serialized_args = serialize_args(TRAIN_KWARGS)
157-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
151+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
158152

159153
assert main() == 0
160154
_validate_termination_files_when_tuning_succeeds(output_dir)
@@ -165,12 +159,12 @@ def test_lora_save_model_dir_separate_dirs():
165159
assert len(checkpoints) == 1
166160

167161

168-
def test_lora_save_model_dir_same_dir_as_output_dir():
162+
def test_lora_save_model_dir_same_dir_as_output_dir(monkeypatch):
169163
"""Run LoRA tuning with same save_model_dir and output_dir.
170164
Verify checkpoints, logs, and model saved to path.
171165
"""
172166
with tempfile.TemporaryDirectory() as tempdir:
173-
setup_env(tempdir)
167+
setup_env(monkeypatch, tempdir)
174168
TRAIN_KWARGS = {
175169
**BASE_LORA_KWARGS,
176170
**{
@@ -180,7 +174,7 @@ def test_lora_save_model_dir_same_dir_as_output_dir():
180174
},
181175
}
182176
serialized_args = serialize_args(TRAIN_KWARGS)
183-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
177+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
184178

185179
assert main() == 0
186180
# check logs, checkpoint dir, and model exists in path
@@ -195,19 +189,21 @@ def test_lora_save_model_dir_same_dir_as_output_dir():
195189
assert len(checkpoints) == TRAIN_KWARGS["num_train_epochs"]
196190

197191

198-
def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no():
192+
def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no(
193+
monkeypatch,
194+
):
199195
"""Run LoRA tuning with same save_model_dir and output_dir and
200196
save_strategy=no. Verify no checkpoints created, only
201197
logs and final model.
202198
"""
203199
with tempfile.TemporaryDirectory() as tempdir:
204-
setup_env(tempdir)
200+
setup_env(monkeypatch, tempdir)
205201
TRAIN_KWARGS = {
206202
**BASE_LORA_KWARGS,
207203
**{"output_dir": tempdir, "save_model_dir": tempdir, "save_strategy": "no"},
208204
}
209205
serialized_args = serialize_args(TRAIN_KWARGS)
210-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
206+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
211207

212208
assert main() == 0
213209
# check that model and logs exists in output_dir
@@ -219,9 +215,9 @@ def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no():
219215
assert len(checkpoints) == 0
220216

221217

222-
def test_lora_with_lora_post_process_for_vllm_set_to_true():
218+
def test_lora_with_lora_post_process_for_vllm_set_to_true(monkeypatch):
223219
with tempfile.TemporaryDirectory() as tempdir:
224-
setup_env(tempdir)
220+
setup_env(monkeypatch, tempdir)
225221
TRAIN_KWARGS = {
226222
**BASE_LORA_KWARGS,
227223
**{
@@ -231,7 +227,7 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
231227
},
232228
}
233229
serialized_args = serialize_args(TRAIN_KWARGS)
234-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
230+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
235231

236232
assert main() == 0
237233
# check that model and logs exists in output_dir
@@ -255,9 +251,9 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
255251
not _is_package_available("HFResourceScanner"),
256252
reason="Only runs if HFResourceScanner is installed",
257253
)
258-
def test_launch_with_HFResourceScanner_enabled():
254+
def test_launch_with_HFResourceScanner_enabled(monkeypatch):
259255
with tempfile.TemporaryDirectory() as tempdir:
260-
setup_env(tempdir)
256+
setup_env(monkeypatch, tempdir)
261257
scanner_outfile = os.path.join(tempdir, TrackerConfigs.scanner_output_filename)
262258
TRAIN_KWARGS = {
263259
**BASE_LORA_KWARGS,
@@ -271,7 +267,7 @@ def test_launch_with_HFResourceScanner_enabled():
271267
},
272268
}
273269
serialized_args = serialize_args(TRAIN_KWARGS)
274-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
270+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
275271

276272
assert main() == 0
277273
assert os.path.exists(scanner_outfile) is True
@@ -281,14 +277,14 @@ def test_launch_with_HFResourceScanner_enabled():
281277
assert scanner_res["mem_data"] is not None
282278

283279

284-
def test_bad_script_path():
280+
def test_bad_script_path(monkeypatch):
285281
"""Check for appropriate error for an invalid training script location"""
286282
with tempfile.TemporaryDirectory() as tempdir:
287-
setup_env(tempdir)
283+
setup_env(monkeypatch, tempdir)
288284
TRAIN_KWARGS = {**BASE_LORA_KWARGS, **{"output_dir": tempdir}}
289285
serialized_args = serialize_args(TRAIN_KWARGS)
290-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
291-
os.environ["TRAINING_SCRIPT"] = "/not/here"
286+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
287+
monkeypatch.setenv("TRAINING_SCRIPT", "/not/here")
292288

293289
with pytest.raises(SystemExit) as pytest_wrapped_e:
294290
main()
@@ -297,61 +293,61 @@ def test_bad_script_path():
297293
assert os.stat(tempdir + "/termination-log").st_size > 0
298294

299295

300-
def test_blank_env_var():
296+
def test_blank_env_var(monkeypatch):
301297
with tempfile.TemporaryDirectory() as tempdir:
302-
setup_env(tempdir)
303-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = ""
298+
setup_env(monkeypatch, tempdir)
299+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", "")
304300
with pytest.raises(SystemExit) as pytest_wrapped_e:
305301
main()
306302
assert pytest_wrapped_e.type == SystemExit
307303
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
308304
assert os.stat(tempdir + "/termination-log").st_size > 0
309305

310306

311-
def test_faulty_file_path():
307+
def test_faulty_file_path(monkeypatch):
312308
with tempfile.TemporaryDirectory() as tempdir:
313-
setup_env(tempdir)
309+
setup_env(monkeypatch, tempdir)
314310
faulty_path = os.path.join(tempdir, "non_existent_file.pkl")
315311
TRAIN_KWARGS = {
316312
**BASE_LORA_KWARGS,
317313
**{"training_data_path": faulty_path, "output_dir": tempdir},
318314
}
319315
serialized_args = serialize_args(TRAIN_KWARGS)
320-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
316+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
321317
with pytest.raises(SystemExit) as pytest_wrapped_e:
322318
main()
323319
assert pytest_wrapped_e.type == SystemExit
324320
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
325321
assert os.stat(tempdir + "/termination-log").st_size > 0
326322

327323

328-
def test_bad_base_model_path():
324+
def test_bad_base_model_path(monkeypatch):
329325
with tempfile.TemporaryDirectory() as tempdir:
330-
setup_env(tempdir)
326+
setup_env(monkeypatch, tempdir)
331327
TRAIN_KWARGS = {
332328
**BASE_LORA_KWARGS,
333329
**{"model_name_or_path": "/wrong/path"},
334330
"output_dir": tempdir,
335331
}
336332
serialized_args = serialize_args(TRAIN_KWARGS)
337-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
333+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
338334
with pytest.raises(SystemExit) as pytest_wrapped_e:
339335
main()
340336
assert pytest_wrapped_e.type == SystemExit
341337
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
342338
assert os.stat(tempdir + "/termination-log").st_size > 0
343339

344340

345-
def test_config_parsing_error():
341+
def test_config_parsing_error(monkeypatch):
346342
with tempfile.TemporaryDirectory() as tempdir:
347-
setup_env(tempdir)
343+
setup_env(monkeypatch, tempdir)
348344
TRAIN_KWARGS = {
349345
**BASE_LORA_KWARGS,
350346
**{"num_train_epochs": "five"},
351347
"output_dir": tempdir,
352348
} # Intentional type error
353349
serialized_args = serialize_args(TRAIN_KWARGS)
354-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
350+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
355351
with pytest.raises(SystemExit) as pytest_wrapped_e:
356352
main()
357353
assert pytest_wrapped_e.type == SystemExit
@@ -376,9 +372,3 @@ def _validate_training_output(base_dir, tuning_technique):
376372
else:
377373
assert os.path.exists(base_dir + "/adapter_config.json") is True
378374
assert os.path.exists(base_dir + "/adapter_model.safetensors") is True
379-
380-
381-
def test_cleanup():
382-
# This runs to unset env variables that could disrupt other tests
383-
cleanup_env()
384-
assert True

tests/build/test_utils.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,9 +55,17 @@ def test_process_accelerate_launch_args(job_config):
5555

5656

5757
@patch("torch.cuda.device_count", return_value=1)
58-
def test_accelerate_launch_args_user_set_num_processes_ignored(job_config):
58+
def test_accelerate_launch_args_user_set_num_processes_ignored(
59+
_mock_cuda_count, job_config, monkeypatch
60+
):
5961
job_config_copy = copy.deepcopy(job_config)
6062
job_config_copy["accelerate_launch_args"]["num_processes"] = "3"
63+
if "CUDA_VISIBLE_DEVICES" in os.environ:
64+
monkeypatch.setenv(
65+
"CUDA_VISIBLE_DEVICES", os.environ["CUDA_VISIBLE_DEVICES"]
66+
)
67+
else:
68+
monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
6169
args = process_accelerate_launch_args(job_config_copy)
6270
# determine number of processes by number of GPUs available
6371
assert args.num_processes == 1

tests/test_sft_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -801,7 +801,7 @@ def test_run_causallm_alora_and_inference(request, target_modules, expected):
801801
assert "Simply put, the theory of relativity states that \n" in output_inference
802802

803803

804-
def test_successful_lora_target_modules_default_from_main():
804+
def test_successful_lora_target_modules_default_from_main(monkeypatch):
805805
"""Check that if target_modules is not set, or set to None via JSON, the
806806
default value by model type will be using in LoRA tuning.
807807
The correct default target modules will be used for model type llama
@@ -818,7 +818,7 @@ def test_successful_lora_target_modules_default_from_main():
818818
**{"peft_method": "lora", "output_dir": tempdir},
819819
}
820820
serialized_args = serialize_args(TRAIN_KWARGS)
821-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
821+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", serialized_args)
822822

823823
sft_trainer.main()
824824

tests/utils/test_config_utils.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -208,28 +208,30 @@ def test_update_config_can_handle_multiple_config_updates():
208208
assert config[1].r == 98
209209

210210

211-
def test_get_json_config_can_load_from_path():
211+
def test_get_json_config_can_load_from_path(monkeypatch):
212212
"""Test that the function get_json_config can read
213213
the json path from env var SFT_TRAINER_CONFIG_JSON_PATH
214214
"""
215-
if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ:
216-
del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]
217-
os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH
215+
monkeypatch.delenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", raising=False)
216+
monkeypatch.setenv(
217+
"SFT_TRAINER_CONFIG_JSON_PATH", HAPPY_PATH_DUMMY_CONFIG_PATH
218+
)
218219

219220
job_config = config_utils.get_json_config()
220221
assert job_config is not None
221222
assert job_config["model_name_or_path"] == "bigscience/bloom-560m"
222223

223224

224-
def test_get_json_config_can_load_from_envvar():
225+
def test_get_json_config_can_load_from_envvar(monkeypatch):
225226
"""Test that the function get_json_config can read
226227
the json path from env var SFT_TRAINER_CONFIG_JSON_ENV_VAR
227228
"""
228229
config_json = {"model_name_or_path": "foobar"}
229230
message_bytes = pickle.dumps(config_json)
230231
base64_bytes = base64.b64encode(message_bytes)
231232
encoded_json = base64_bytes.decode("ascii")
232-
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = encoded_json
233+
monkeypatch.delenv("SFT_TRAINER_CONFIG_JSON_PATH", raising=False)
234+
monkeypatch.setenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR", encoded_json)
233235

234236
job_config = config_utils.get_json_config()
235237
assert job_config is not None

0 commit comments

Comments
 (0)