@@ -251,6 +251,32 @@ def test_save_hyperparameters(tmp_path):
251251 assert hparams ["bar" ] == 1
252252
253253
254+ def _test_function2 (out_dir : Path , foo : bool = False , bar : int = 1 ):
255+ assert False , "I only exist as a signature, but I should not run."
256+
257+
258+ @pytest .mark .parametrize ("command" , [
259+ "any.py" ,
260+ "litgpt finetune full" ,
261+ "litgpt finetune lora" ,
262+ "litgpt finetune adapter" ,
263+ "litgpt finetune adapter_v2" ,
264+ "litgpt pretrain" ,
265+ ])
266+ def test_save_hyperparameters_known_commands (command , tmp_path ):
267+ from litgpt .utils import save_hyperparameters
268+
269+ with mock .patch ("sys.argv" , [* command .split (" " ), "--out_dir" , str (tmp_path ), "--foo" , "True" ]):
270+ save_hyperparameters (_test_function2 , tmp_path )
271+
272+ with open (tmp_path / "hyperparameters.yaml" , "r" ) as file :
273+ hparams = yaml .full_load (file )
274+
275+ assert hparams ["out_dir" ] == str (tmp_path )
276+ assert hparams ["foo" ] is True
277+ assert hparams ["bar" ] == 1
278+
279+
254280def test_choose_logger (tmp_path ):
255281 from litgpt .utils import choose_logger
256282
0 commit comments