Skip to content

Commit 5021cee

Browse files
awaelchlirasbt
authored andcommitted
Fix save_hyperparameters() for top-level CLI (#1103)
1 parent 1b6a16e commit 5021cee

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

litgpt/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,21 @@ def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
403403
"""Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
404404
from jsonargparse import capture_parser
405405

406+
# TODO: Make this more robust
407+
# This hack strips away the subcommands from the top-level CLI
408+
# to parse the file as if it was called as a script
409+
known_commands = [
410+
("finetune", "full"),
411+
("finetune", "lora"),
412+
("finetune", "adapter"),
413+
("finetune", "adapter_v2"),
414+
("pretrain",),
415+
]
416+
for known_command in known_commands:
417+
unwanted = slice(1, 1 + len(known_command))
418+
if tuple(sys.argv[unwanted]) == known_command:
419+
sys.argv[unwanted] = []
420+
406421
parser = capture_parser(lambda: CLI(function))
407422
config = parser.parse_args()
408423
parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)

tests/test_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,32 @@ def test_save_hyperparameters(tmp_path):
251251
assert hparams["bar"] == 1
252252

253253

254+
def _test_function2(out_dir: Path, foo: bool = False, bar: int = 1):
255+
assert False, "I only exist as a signature, but I should not run."
256+
257+
258+
@pytest.mark.parametrize("command", [
259+
"any.py",
260+
"litgpt finetune full",
261+
"litgpt finetune lora",
262+
"litgpt finetune adapter",
263+
"litgpt finetune adapter_v2",
264+
"litgpt pretrain",
265+
])
266+
def test_save_hyperparameters_known_commands(command, tmp_path):
267+
from litgpt.utils import save_hyperparameters
268+
269+
with mock.patch("sys.argv", [*command.split(" "), "--out_dir", str(tmp_path), "--foo", "True"]):
270+
save_hyperparameters(_test_function2, tmp_path)
271+
272+
with open(tmp_path / "hyperparameters.yaml", "r") as file:
273+
hparams = yaml.full_load(file)
274+
275+
assert hparams["out_dir"] == str(tmp_path)
276+
assert hparams["foo"] is True
277+
assert hparams["bar"] == 1
278+
279+
254280
def test_choose_logger(tmp_path):
255281
from litgpt.utils import choose_logger
256282

0 commit comments

Comments
 (0)