@@ -438,7 +438,11 @@ def __post_init__(self):
438438 self .script = script_path .read_text ()
439439 if len (cmd ) > 1 :
440440 self .import_path = " " .join (cmd [1 :])
441- if cmd [0 ] in ("nemo" , "nemo_run" ):
441+ if (
442+ cmd [0 ] in ("nemo" , "nemo_run" )
443+ or cmd [0 ].endswith ("/nemo" )
444+ or cmd [0 ].endswith ("/nemo_run" )
445+ ):
442446 self .import_path = " " .join (cmd [1 :])
443447
444448 def __call__ (self , * args , ** kwargs ):
@@ -460,18 +464,61 @@ def _load_entrypoint(self):
460464 else :
461465 parts = self .import_path .split (" " )
462466 if parts [0 ] not in entrypoints :
467+ available_cmds = ", " .join (sorted (entrypoints .keys ()))
463468 raise ValueError (
464- f"Entrypoint { parts [0 ]} not found. Available entrypoints: { list ( entrypoints . keys ()) } "
469+ f"Entrypoint ' { parts [0 ]} ' not found. Available top-level entrypoints: { available_cmds } "
465470 )
466471 output = entrypoints [parts [0 ]]
472+
473+ # Re-key the nested entrypoint dict to include 'name' attribute as keys
474+ def rekey_entrypoints (entries ):
475+ if not isinstance (entries , dict ):
476+ return entries
477+
478+ result = {}
479+ for key , value in entries .items ():
480+ result [key ] = value
481+ if hasattr (value , "name" ) and value .name != key :
482+ result [value .name ] = value
483+ elif isinstance (value , dict ):
484+ result [key ] = rekey_entrypoints (value )
485+ return result
486+
487+ # Only rekey if we're dealing with a dictionary
488+ if isinstance (output , dict ):
489+ output = rekey_entrypoints (output )
490+
467491 if len (parts ) > 1 :
468492 for part in parts [1 :]:
469- if part in output :
470- output = output [part ]
493+ # Skip args with - or -- prefix or containing = as they're parameters, not subcommands
494+ if part .startswith ("-" ) or "=" in part :
495+ continue
496+
497+ if isinstance (output , dict ):
498+ if part in output :
499+ output = output [part ]
500+ else :
501+ # Collect available commands for error message
502+ available_cmds = sorted (output .keys ())
503+ raise ValueError (
504+ f"Subcommand '{ part } ' not found for entrypoint '{ parts [0 ]} '. "
505+ f"Available subcommands: { ', ' .join (available_cmds )} "
506+ )
471507 else :
508+ # We've reached an entrypoint object but tried to access a subcommand
509+ entrypoint_name = getattr (output , "name" , parts [0 ])
472510 raise ValueError (
473- f"Entrypoint { self .import_path } not found. Available entrypoints: { list (entrypoints .keys ())} "
511+ f"'{ entrypoint_name } ' is a terminal entrypoint and does not have subcommand '{ part } '. "
512+ f"You may have provided an incorrect command structure."
474513 )
514+
515+ # If output is a dict, we need to get the default entrypoint
516+ if isinstance (output , dict ):
517+ raise ValueError (
518+ f"Incomplete command: '{ self .import_path } '. Please specify a subcommand. "
519+ f"Available subcommands: { ', ' .join (sorted (output .keys ()))} "
520+ )
521+
475522 self ._target_fn = output .fn
476523
477524 @property
@@ -831,8 +878,8 @@ def load_config_from_path(path_with_syntax: str) -> Any:
831878 Examples:
832879 # Nested config (model.yaml):
833880 model:
834- _target_: Model
835- hidden_size: 256
881+ _target_: Model
882+ hidden_size: 256
836883
837884 # Flat config (model.yaml):
838885 _target_: Model
0 commit comments