Skip to content

Commit 2d3271e

Browse files
authored
Handle ctx in entrypoint for experiment (#213)
* Update LazyEntrypoint.resolve to handle ctx Signed-off-by: Marc Romeyn <[email protected]> * Update LazyEntrypoint.resolve to handle ctx Signed-off-by: Marc Romeyn <[email protected]> * Fixing failing tests Signed-off-by: Marc Romeyn <[email protected]> * Fix linting issue Signed-off-by: Marc Romeyn <[email protected]> * Trying to fix failing tests Signed-off-by: Marc Romeyn <[email protected]> --------- Signed-off-by: Marc Romeyn <[email protected]>
1 parent 0d271a9 commit 2d3271e

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

examples/entrypoint/experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,4 @@ def train_models_experiment(
104104

105105

106106
if __name__ == "__main__":
107-
run.cli.main(train_models_experiment)
107+
run.cli.main(train_models_experiment, default_executor=local_executor())

nemo_run/cli/api.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,10 +1232,28 @@ def parse_fn(self, fn: T, args: List[str], **default_kwargs) -> Partial[T]:
12321232
Returns:
12331233
Partial[T]: A Partial object representing the parsed function and arguments.
12341234
"""
1235-
output = LazyEntrypoint(fn, factory=self.factory, yaml=self.yaml, overwrites=args)
1236-
out = output.resolve()
1235+
lazy = LazyEntrypoint(
1236+
fn,
1237+
factory=self.factory,
1238+
yaml=self.yaml,
1239+
overwrites=args,
1240+
)
12371241

1238-
return out
1242+
# Resolve exactly once and always pass the current RunContext
1243+
# NOTE: `LazyEntrypoint.resolve` calls `parse_factory` if
1244+
# `lazy._factory_` is a string. `parse_cli_args` that follows inside
1245+
# `resolve` used to see the **same** string and call `parse_factory`
1246+
# a second time. We temporarily clear `_factory_` right after the
1247+
# first resolution so that it cannot be triggered again.
1248+
1249+
_orig_factory = lazy._factory_
1250+
try:
1251+
result = lazy.resolve(ctx=self)
1252+
finally:
1253+
# Restore for potential further use
1254+
lazy._factory_ = _orig_factory
1255+
1256+
return result
12391257

12401258
def _parse_partial(self, fn: Callable, args: List[str], **default_args) -> Partial[T]:
12411259
"""

nemo_run/cli/lazy.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import shlex
88
import sys
99
from dataclasses import dataclass, field
10+
import inspect
1011
from pathlib import Path
1112
from types import ModuleType
12-
from typing import Any, Callable, Iterator
13+
from typing import Any, Callable, Iterator, Optional, TYPE_CHECKING
1314

1415
from fiddle import Buildable, daglish
1516
from fiddle._src import signatures
@@ -19,6 +20,9 @@
1920

2021
from nemo_run.config import Partial
2122

23+
if TYPE_CHECKING:
24+
from nemo_run.cli.cli_parser import RunContext
25+
2226

2327
@contextlib.contextmanager
2428
def lazy_imports(fallback_to_lazy: bool = False) -> Iterator[None]:
@@ -142,7 +146,7 @@ def __init__(
142146
if remaining_overwrites:
143147
self._add_overwrite(*remaining_overwrites)
144148

145-
def resolve(self) -> Partial:
149+
def resolve(self, ctx: Optional["RunContext"] = None) -> Partial:
146150
from nemo_run.cli.cli_parser import parse_cli_args, parse_factory
147151

148152
fn = self._target_
@@ -160,12 +164,26 @@ def resolve(self) -> Partial:
160164
if isinstance(fn, LazyTarget):
161165
fn = fn.target
162166

167+
_fn = fn
168+
if hasattr(fn, "__fn_or_cls__"):
169+
_fn = fn.__fn_or_cls__
170+
171+
sig = inspect.signature(_fn)
172+
param_names = sig.parameters.keys()
173+
163174
dotlist = dictconfig_to_dot_list(
164175
_args_to_dictconfig(self._args_), has_factory=self._factory_ is not None
165176
)
166177
_args = [f"{name}{op}{value}" for name, op, value in dotlist]
167178

168-
return parse_cli_args(fn, _args)
179+
out = parse_cli_args(fn, _args)
180+
181+
if "ctx" in param_names:
182+
if not ctx:
183+
raise ValueError("ctx is required for this function")
184+
out.ctx = ctx
185+
186+
return out
169187

170188
def __getattr__(self, item: str) -> "LazyEntrypoint":
171189
"""

0 commit comments

Comments
 (0)