Commit 74cf902

Support Annotated in CLI (#118)
* Support Annotated in CLI
  Signed-off-by: Marc Romeyn <[email protected]>
* Fix failing test
  Signed-off-by: Marc Romeyn <[email protected]>
1 parent 6c2eee0 commit 74cf902

File tree

2 files changed: +59 −4 lines

src/nemo_run/cli/cli_parser.py

Lines changed: 8 additions & 0 deletions

@@ -702,6 +702,14 @@ def parse_buildable(self, value: str, annotation: Type[Config | Partial]) -> Con
         elif buildable_type == "Partial":
             return Partial(config_type)
 
+        if str(annotation).startswith("typing.Annotated"):
+            args = get_args(annotation)
+            if str(args[0]).startswith("typing.Optional") and len(args) > 1:
+                cfg_type = get_args(args[0])[0]
+                buildable = args[1].__origin__
+                if issubclass(buildable, fdl.Buildable):
+                    return buildable(cfg_type)
+
         return Config(annotation)
 
     def parse_int(self, value: str, _: Type) -> int:
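
What the new branch does, as a minimal self-contained sketch: for a parameter annotated Annotated[Optional[Optimizer], run.Config[Optimizer]], typing.get_args yields the base type plus the metadata, and the metadata's __origin__ recovers the buildable class. The Config and Optimizer classes below are hypothetical stand-ins (the real code additionally guards with issubclass(buildable, fdl.Buildable)):

from typing import Annotated, Generic, Optional, TypeVar, get_args

T = TypeVar("T")


class Config(Generic[T]):  # hypothetical stand-in for nemo_run's fiddle-backed run.Config
    def __init__(self, cls):
        self.cls = cls


class Optimizer:  # stand-in for the test fixture below
    pass


annotation = Annotated[Optional[Optimizer], Config[Optimizer]]

# get_args on Annotated returns (base type, *metadata).
args = get_args(annotation)        # (Optional[Optimizer], Config[Optimizer])
# Optional reprs as "typing.Optional[...]" on Python 3.9+, which is what the
# string check in the diff relies on.
assert str(args[0]).startswith("typing.Optional")

cfg_type = get_args(args[0])[0]    # Optimizer, unwrapped from Optional
buildable = args[1].__origin__     # the plain Config class behind Config[Optimizer]
config = buildable(cfg_type)       # equivalent to Config(Optimizer)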

test/cli/test_api.py

Lines changed: 51 additions & 4 deletions

@@ -15,7 +15,7 @@
 
 from configparser import ConfigParser
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import Annotated, List, Optional, Tuple, Union
 from unittest.mock import Mock, patch
 
 import fiddle as fdl
@@ -439,9 +439,9 @@ class Model:
 class Optimizer:
     """Dummy optimizer config"""
 
-    learning_rate: float
-    weight_decay: float
-    betas: List[float]
+    learning_rate: float = 0.001
+    weight_decay: float = 1e-5
+    betas: Tuple[float, float] = (0.9, 0.999)
 
 
 @run.cli.factory
@@ -506,6 +506,22 @@ def train_model(
     return {"model": model, "optimizer": optimizer, "epochs": epochs, "batch_size": batch_size}
 
 
+@run.cli.entrypoint(
+    namespace="my_llm",
+    skip_confirmation=True,
+)
+def train_model_default_optimizer(
+    model: Model,
+    optimizer: Annotated[Optional[Optimizer], run.Config[Optimizer]] = None,
+    epochs: int = 10,
+    batch_size: int = 32,
+):
+    if optimizer is None:
+        optimizer = Optimizer()
+
+    return train_model(model, optimizer, epochs, batch_size)
+
+
 @run.cli.factory(target=train_model)
 def custom_defaults() -> run.Partial["train_model"]:
     return run.Partial(
@@ -585,6 +601,37 @@ def test_with_defaults(self, runner, app):
         for i in range(1, 31):
             assert f"Epoch {i}/30" in output
 
+    def test_with_defaults_no_optimizer(self, runner, app):
+        # Test CLI execution with default factory
+        result = runner.invoke(
+            app,
+            [
+                "my_llm",
+                "train_model_default_optimizer",
+                "model=my_model(hidden_size=1024)",
+                "epochs=30",
+                "run.skip_confirmation=True",
+            ],
+            env={"INCLUDE_WORKSPACE_FILE": "false"},
+        )
+        assert result.exit_code == 0
+
+        # Parse the output to check the values
+        output = result.stdout
+        assert "Training model with the following configuration:" in output
+        assert "Model: Model(hidden_size=1024, num_layers=3, activation='relu')" in output
+        assert (
+            "Optimizer: Optimizer(learning_rate=0.001, weight_decay=1e-05, betas=(0.9, 0.999))"
+            in output
+        )
+        assert "Epochs: 30" in output
+        assert "Batch size: 32" in output
+        assert "Training completed!" in output
+
+        # Check that all epochs were simulated
+        for i in range(1, 31):
+            assert f"Epoch {i}/30" in output
+
     def test_experiment_entrypoint(self):
         def dummy_pretrain(log_dir: str):
             pass
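
End to end, the change lets an entrypoint declare an optional config parameter that the parser can still materialize when the CLI supplies a value. A hedged sketch of the call-site pattern, assuming the conventional import nemo_run as run (mirroring the test fixtures above; not authoritative library documentation):

from dataclasses import dataclass
from typing import Annotated, Optional

import nemo_run as run


@dataclass
class Optimizer:
    learning_rate: float = 0.001


@run.cli.entrypoint(namespace="my_llm", skip_confirmation=True)
def train(optimizer: Annotated[Optional[Optimizer], run.Config[Optimizer]] = None):
    # When the CLI omits optimizer, it arrives as None and the Python-side
    # default applies; when the CLI supplies one, parse_buildable now builds
    # run.Config(Optimizer) from the Annotated metadata instead of falling
    # through to Config(annotation) on the whole Annotated type.
    if optimizer is None:
        optimizer = Optimizer()
    return optimizer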
