Skip to content

Commit cd9196d

Browse files
authored
Fix bug with a CLI overwrite (#235)
Signed-off-by: Marc Romeyn <[email protected]>
1 parent 78f54ee commit cd9196d

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

nemo_run/cli/lazy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,10 @@ def _args_to_dictconfig(args: list[tuple[str, str, Any]]) -> DictConfig:
721721
if part not in current:
722722
current[part] = {}
723723
elif not isinstance(current[part], dict):
724-
current[part] = {} # Convert to dict if it wasn't already
724+
if isinstance(current[part], str):
725+
current[part] = {"_factory_": current[part]}
726+
else:
727+
current[part] = {} # Convert to dict if it wasn't already
725728
current = current[part]
726729

727730
# Add the operator suffix if it's not a simple assignment

test/cli/test_api.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,13 @@ def my_model(hidden_size: int = 256, num_layers: int = 3, activation: str = "rel
631631
return Model(hidden_size=hidden_size, num_layers=num_layers, activation=activation)
632632

633633

634+
@run.cli.factory
635+
@run.autoconvert
636+
def my_other_model(hidden_size: int = 512, num_layers: int = 3, activation: str = "relu") -> Model:
637+
"""Create a model configuration."""
638+
return Model(hidden_size=hidden_size, num_layers=num_layers, activation=activation)
639+
640+
634641
@run.cli.factory
635642
def my_optimizer(
636643
learning_rate: float = 0.001, weight_decay: float = 1e-5, betas: List[float] = [0.9, 0.999]
@@ -880,6 +887,26 @@ class SomeObject:
880887
value_1: int
881888
value_2: int
882889

890+
def test_with_factory_and_overwrite(self, runner, app):
891+
# Test CLI execution with factory and parameter overwrite
892+
result = runner.invoke(
893+
app,
894+
[
895+
"my_llm",
896+
"train_model",
897+
"model=my_other_model",
898+
"model.num_layers=10",
899+
"--yes",
900+
],
901+
env={"INCLUDE_WORKSPACE_FILE": "false"},
902+
)
903+
assert result.exit_code == 0
904+
905+
output = result.stdout
906+
assert "Training model with the following configuration:" in output
907+
# Check that my_model_2's default hidden_size (512) is used
908+
assert "Model: Model(hidden_size=512, num_layers=10, activation='relu')" in output
909+
883910

884911
class TestDefaultFactory:
885912
def test_default_factory(self):

0 commit comments

Comments
 (0)