Skip to content

Commit 0e87190

Browse files
authored
Add validation for run.Partial based tasks in Experiment (#103)
* Add validation for run.Partial based tasks in Experiment
  Signed-off-by: Hemil Desai <[email protected]>
* Add tests
  Signed-off-by: Hemil Desai <[email protected]>
* fix
  Signed-off-by: Hemil Desai <[email protected]>
---------
Signed-off-by: Hemil Desai <[email protected]>
1 parent 59d8dae commit 0e87190

File tree

4 files changed

+75
-2
lines changed

4 files changed

+75
-2
lines changed

src/nemo_run/run/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def run(
7575
)
7676
name = name or default_name
7777
with Experiment(title=name, executor=executor, log_level=log_level) as exp:
78-
exp.add(fn_or_script, tail_logs=tail_logs, plugins=plugins)
78+
exp.add(fn_or_script, tail_logs=tail_logs, plugins=plugins, name=name)
7979
if dryrun:
8080
exp.dryrun()
8181
return

src/nemo_run/run/experiment.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import inspect
2020
import json
2121
import os
22+
import pprint
2223
import shutil
2324
import sys
2425
import time
@@ -29,6 +30,7 @@
2930

3031
import fiddle as fdl
3132
import networkx as nx
33+
from fiddle._src import daglish, diffing
3234
from rich.console import Group
3335
from rich.live import Live
3436
from rich.panel import Panel
@@ -263,6 +265,7 @@ def __init__(
263265
log_level: str = "INFO",
264266
_reconstruct: bool = False,
265267
jobs: list[Job | JobGroup] | None = None,
268+
base_dir: str | None = None,
266269
) -> None:
267270
"""
268271
Initializes an experiment run by creating its metadata directory and saving the experiment config.
@@ -286,7 +289,8 @@ def __init__(
286289
self._title = title
287290
self._id = id or f"{title}_{int(time.time())}"
288291

289-
self._exp_dir = os.path.join(NEMORUN_HOME, "experiments", title, self._id)
292+
base_dir = base_dir or NEMORUN_HOME
293+
self._exp_dir = os.path.join(base_dir, "experiments", title, self._id)
290294

291295
self.log_level = log_level
292296
self._runner = get_runner()
@@ -378,6 +382,8 @@ def _add_single_job(
378382
else:
379383
task_id = name
380384

385+
self._validate_task(task_info=task_id, task=task)
386+
381387
executor = executor.clone()
382388
executor.assign(
383389
self._id,
@@ -418,6 +424,9 @@ def _add_job_group(
418424
else:
419425
task_id = name
420426

427+
for i, _task in enumerate(tasks):
428+
self._validate_task(task_info=f"Job Group: {task_id}, job index: {i}", task=_task)
429+
421430
executors = executor if isinstance(executor, list) else [executor]
422431
cloned_executors = []
423432
for executor in executors:
@@ -449,6 +458,30 @@ def _add_job_group(
449458
self._jobs.append(job_group)
450459
return job_group.id
451460

461+
def _validate_task(self, task_info: str, task: Union[Partial, Script]) -> None:
    """Check that a ``Partial`` task survives a serialize/deserialize round trip.

    Tasks that are not ``Partial`` instances are accepted without any checks.
    A ``Partial`` is serialized with ``ZlibJSONSerializer``, deserialized
    again, and compared to the original.

    Args:
        task_info: Label identifying the task in the error message (a task id,
            or a job-group id plus job index).
        task: The task to validate.

    Raises:
        RuntimeError: If the deserialized task is not equal to the original.
            The message lists the differing paths and their values in ``task``.

    Exceptions raised by the serializer itself are not caught here and
    propagate to the caller unchanged.
    """
    if not isinstance(task, Partial):
        return

    serializer = ZlibJSONSerializer()
    deserialized = serializer.deserialize(serializer.serialize(task))

    if deserialized != task:
        # Only build the diff on the failure path; it is needed solely for the
        # error message and is wasted work when the round trip is lossless.
        changes = diffing.build_diff(deserialized, task).changes
        # Map each differing path to the value found in the original task.
        # Operations that carry no ``new_value`` attribute are reported as None.
        changed_paths = {
            daglish.path_str(op.target): getattr(op, "new_value", None) for op in changes
        }
        message = f"""
Deserialized task does not match original task. The following paths in your task need to be wrapped in `run.Config` or `run.Partial`:

{pprint.PrettyPrinter(indent=4).pformat(changed_paths)}

For more information about `run.Config` and `run.Partial`, please refer to https://github.com/NVIDIA/NeMo-Run/blob/main/docs/source/guides/configuration.md.
"""
        raise RuntimeError(f"Failed to validate task {task_info}.\n{message}")
484+
452485
def add(
453486
self,
454487
task: Union[Partial, Script] | list[Union[Partial, Script]],

test/dummy_factory.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@ class DummyModel:
2525
activation: str = "relu"
2626

2727

28+
class DummyTrainer:
    """Plain (non-dataclass) trainer stand-in with value-based equality and hashing.

    Two instances are equal when their ``num_epochs`` match, and equal
    instances hash identically.
    """

    def __init__(self, num_epochs: int = 10):
        self.num_epochs = num_epochs

    def __repr__(self) -> str:
        # A readable repr keeps pretty-printed values and assertion failures legible.
        return f"{type(self).__name__}(num_epochs={self.num_epochs!r})"

    def __hash__(self) -> int:
        # Hash on the same field ``__eq__`` compares so the eq/hash contract holds.
        return hash(self.num_epochs)

    def __eq__(self, value: object) -> bool:
        return isinstance(value, DummyTrainer) and self.num_epochs == value.num_epochs
37+
38+
2839
@dataclass
2940
class NestedModel:
3041
dummy: DummyModel
@@ -88,5 +99,8 @@ def plugin_list(arg: int = 20) -> List[run.Plugin]:
8899
]
89100

90101

102+
def dummy_train(dummy_model: DummyModel, dummy_trainer: DummyTrainer):
    """No-op training entrypoint: accepts a model and a trainer and returns ``None``."""
    return None
103+
104+
91105
if __name__ == "__main__":
92106
run.cli.main(dummy_entrypoint)

test/run/test_experiment.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
from fiddle._src.experimental.serialization import UnserializableValueError
3+
4+
import nemo_run as run
5+
from test.dummy_factory import DummyModel, DummyTrainer, dummy_train
6+
7+
8+
@pytest.fixture
def experiment(tmpdir):
    """Provide a ``run.Experiment`` titled "dummy_experiment" rooted in the per-test temp directory."""
    # ``Experiment.__init__`` annotates ``base_dir`` as ``str | None``; the
    # ``tmpdir`` fixture yields a path object, so convert it explicitly.
    return run.Experiment("dummy_experiment", base_dir=str(tmpdir))
11+
12+
13+
class TestValidateTask:
    """Tests for ``Experiment._validate_task``."""

    def test_validate_task(self, experiment: run.Experiment):
        """A script and a fully configured partial pass; raw instance arguments raise."""
        # A script task is expected to validate without raising.
        script_task = run.Script(inline="echo 'hello world'")
        experiment._validate_task("valid_script", script_task)

        # Every argument is wrapped in ``run.Config``, so validation should pass.
        configured_task = run.Partial(
            dummy_train,
            dummy_model=run.Config(DummyModel),
            dummy_trainer=run.Config(DummyTrainer),
        )
        experiment._validate_task("valid_partial", configured_task)

        # Arguments are passed as live instances instead of configs; the
        # serializer is expected to reject them.
        raw_instance_task = run.Partial(
            dummy_train,
            dummy_model=DummyModel(),
            dummy_trainer=DummyTrainer(),
        )
        with pytest.raises(UnserializableValueError):
            experiment._validate_task("invalid_partial", raw_instance_task)

0 commit comments

Comments
 (0)