Skip to content

Commit 2b59fd5

Browse files
authored
Fix some bugs for --lazy in CLI (#179)
* Rebasing Signed-off-by: Marc Romeyn <[email protected]> * Run linting Signed-off-by: Marc Romeyn <[email protected]> * Fix failing test Signed-off-by: Marc Romeyn <[email protected]> * Fix wrong rebase Signed-off-by: Marc Romeyn <[email protected]> * Fix wrong rebase Signed-off-by: Marc Romeyn <[email protected]> --------- Signed-off-by: Marc Romeyn <[email protected]>
1 parent 99b3914 commit 2b59fd5

File tree

3 files changed

+60
-9
lines changed

3 files changed

+60
-9
lines changed

nemo_run/cli/api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,8 @@ def list_factories(type_or_namespace: Type | str) -> list[Callable]:
533533

534534

535535
def create_cli(
536-
add_verbose_callback: bool = False, nested_entrypoints_creation: bool = True
536+
add_verbose_callback: bool = False,
537+
nested_entrypoints_creation: bool = True,
537538
) -> Typer:
538539
app: Typer = Typer(pretty_exceptions_enable=False)
539540
entrypoints = metadata.entry_points().select(group="nemo_run.cli")
@@ -960,14 +961,14 @@ def command(
960961
if default_plugins:
961962
self.plugins = default_plugins
962963

964+
_load_workspace()
963965
if isinstance(fn, LazyEntrypoint):
964966
self.execute_lazy(fn, sys.argv, name)
965967
return
966968

967969
try:
968970
if not is_main:
969971
_load_entrypoints()
970-
_load_workspace()
971972
self.cli_execute(fn, ctx.args, type)
972973
except RunContextError as e:
973974
if not verbose:

nemo_run/cli/lazy.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,11 @@ def __post_init__(self):
438438
self.script = script_path.read_text()
439439
if len(cmd) > 1:
440440
self.import_path = " ".join(cmd[1:])
441-
if cmd[0] in ("nemo", "nemo_run"):
441+
if (
442+
cmd[0] in ("nemo", "nemo_run")
443+
or cmd[0].endswith("/nemo")
444+
or cmd[0].endswith("/nemo_run")
445+
):
442446
self.import_path = " ".join(cmd[1:])
443447

444448
def __call__(self, *args, **kwargs):
@@ -460,18 +464,61 @@ def _load_entrypoint(self):
460464
else:
461465
parts = self.import_path.split(" ")
462466
if parts[0] not in entrypoints:
467+
available_cmds = ", ".join(sorted(entrypoints.keys()))
463468
raise ValueError(
464-
f"Entrypoint {parts[0]} not found. Available entrypoints: {list(entrypoints.keys())}"
469+
f"Entrypoint '{parts[0]}' not found. Available top-level entrypoints: {available_cmds}"
465470
)
466471
output = entrypoints[parts[0]]
472+
473+
# Re-key the nested entrypoint dict to include 'name' attribute as keys
474+
def rekey_entrypoints(entries):
475+
if not isinstance(entries, dict):
476+
return entries
477+
478+
result = {}
479+
for key, value in entries.items():
480+
result[key] = value
481+
if hasattr(value, "name") and value.name != key:
482+
result[value.name] = value
483+
elif isinstance(value, dict):
484+
result[key] = rekey_entrypoints(value)
485+
return result
486+
487+
# Only rekey if we're dealing with a dictionary
488+
if isinstance(output, dict):
489+
output = rekey_entrypoints(output)
490+
467491
if len(parts) > 1:
468492
for part in parts[1:]:
469-
if part in output:
470-
output = output[part]
493+
# Skip args with - or -- prefix or containing = as they're parameters, not subcommands
494+
if part.startswith("-") or "=" in part:
495+
continue
496+
497+
if isinstance(output, dict):
498+
if part in output:
499+
output = output[part]
500+
else:
501+
# Collect available commands for error message
502+
available_cmds = sorted(output.keys())
503+
raise ValueError(
504+
f"Subcommand '{part}' not found for entrypoint '{parts[0]}'. "
505+
f"Available subcommands: {', '.join(available_cmds)}"
506+
)
471507
else:
508+
# We've reached an entrypoint object but tried to access a subcommand
509+
entrypoint_name = getattr(output, "name", parts[0])
472510
raise ValueError(
473-
f"Entrypoint {self.import_path} not found. Available entrypoints: {list(entrypoints.keys())}"
511+
f"'{entrypoint_name}' is a terminal entrypoint and does not have subcommand '{part}'. "
512+
f"You may have provided an incorrect command structure."
474513
)
514+
515+
# If output is a dict, we need to get the default entrypoint
516+
if isinstance(output, dict):
517+
raise ValueError(
518+
f"Incomplete command: '{self.import_path}'. Please specify a subcommand. "
519+
f"Available subcommands: {', '.join(sorted(output.keys()))}"
520+
)
521+
475522
self._target_fn = output.fn
476523

477524
@property
@@ -831,8 +878,8 @@ def load_config_from_path(path_with_syntax: str) -> Any:
831878
Examples:
832879
# Nested config (model.yaml):
833880
model:
834-
_target_: Model
835-
hidden_size: 256
881+
_target_: Model
882+
hidden_size: 256
836883
837884
# Flat config (model.yaml):
838885
_target_: Model

nemo_run/run/api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ def run(
7070
if plugins:
7171
plugins = [plugins] if not isinstance(plugins, list) else plugins
7272

73+
if getattr(fn_or_script, "is_lazy", False):
74+
fn_or_script = fn_or_script.resolve()
75+
7376
default_name = (
7477
fn_or_script.get_name()
7578
if isinstance(fn_or_script, Script)

0 commit comments

Comments
 (0)