Skip to content

Commit c5c1f6f

Browse files
committed
Add option to customize experiment dir
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent 143b17c commit c5c1f6f

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

nemo_run/run/experiment.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@
3535
from rich.console import Group
3636
from rich.live import Live
3737
from rich.panel import Panel
38-
from rich.progress import BarColumn, Progress, SpinnerColumn
38+
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskID, TimeElapsedColumn
3939
from rich.progress import Task as RichTask
40-
from rich.progress import TaskID, TimeElapsedColumn
4140
from rich.syntax import Syntax
4241
from torchx.specs.api import AppState
4342

@@ -225,11 +224,13 @@ class Experiment(ConfigurableMixin):
225224
def catalog(
226225
cls: Type["Experiment"],
227226
title: str = "",
227+
exp_dir_infix: str | None = None,
228228
) -> list[str]:
229229
"""
230230
List all experiments inside get_nemorun_home(), optionally with the provided title.
231231
"""
232-
parent_dir = os.path.join(get_nemorun_home(), "experiments", title)
232+
exp_dir_infix = exp_dir_infix or os.path.join("experiments", title)
233+
parent_dir = os.path.join(get_nemorun_home(), exp_dir_infix)
233234
return _get_sorted_dirs(parent_dir)
234235

235236
@classmethod
@@ -263,12 +264,13 @@ def _from_config(cls: Type["Experiment"], exp_dir: str) -> "Experiment":
263264
def from_id(
264265
cls: Type["Experiment"],
265266
id: str,
267+
exp_dir_infix: str | None = None,
266268
) -> "Experiment":
267269
"""
268270
Reconstruct an experiment with the specified id.
269271
"""
270-
title, _, _ = id.rpartition("_")
271-
parent_dir = os.path.join(get_nemorun_home(), "experiments", title)
272+
exp_dir_infix = exp_dir_infix or os.path.join("experiments", id.rpartition("_")[0])
273+
parent_dir = os.path.join(get_nemorun_home(), exp_dir_infix)
272274
exp_dir = os.path.join(parent_dir, id)
273275

274276
assert os.path.isdir(exp_dir), f"Experiment {id} not found."
@@ -280,11 +282,13 @@ def from_id(
280282
def from_title(
281283
cls: Type["Experiment"],
282284
title: str,
285+
exp_dir_infix: str | None = None,
283286
) -> "Experiment":
284287
"""
285288
Reconstruct an experiment with the specified title.
286289
"""
287-
parent_dir = os.path.join(get_nemorun_home(), "experiments", title)
290+
exp_dir_infix = exp_dir_infix or os.path.join("experiments", title)
291+
parent_dir = os.path.join(get_nemorun_home(), exp_dir_infix)
288292
exp_dir = _get_latest_dir(parent_dir)
289293

290294
assert os.path.isdir(exp_dir), f"Experiment {id} not found."
@@ -303,6 +307,7 @@ def __init__(
303307
base_dir: str | None = None,
304308
clean_mode: bool = False,
305309
enable_goodbye_message: bool = True,
310+
exp_dir_infix: str | None = None,
306311
) -> None:
307312
"""
308313
Initializes an experiment run by creating its metadata directory and saving the experiment config.
@@ -330,7 +335,8 @@ def __init__(
330335
self._enable_goodbye_message = enable_goodbye_message
331336

332337
base_dir = str(base_dir or get_nemorun_home())
333-
self._exp_dir = os.path.join(base_dir, "experiments", title, self._id)
338+
self._exp_dir_infix = exp_dir_infix or os.path.join("experiments", title)
339+
self._exp_dir = os.path.join(base_dir, self._exp_dir_infix, self._id)
334340

335341
self.log_level = log_level
336342
self._runner = get_runner(component_defaults=None, experiment=self)
@@ -359,6 +365,7 @@ def to_config(self) -> Config:
359365
executor=self.executor.to_config(),
360366
log_level=self.log_level,
361367
clean_mode=self.clean_mode,
368+
exp_dir_infix=self._exp_dir_infix,
362369
)
363370

364371
def _save_experiment(self, exist_ok: bool = False):
@@ -997,7 +1004,7 @@ def reset(self) -> "Experiment":
9971004

9981005
old_id, old_exp_dir, old_launched = self._id, self._exp_dir, self._launched
9991006
self._id = f"{self._title}_{int(time.time())}"
1000-
self._exp_dir = os.path.join(get_nemorun_home(), "experiments", self._title, self._id)
1007+
self._exp_dir = os.path.join(get_nemorun_home(), self._exp_dir_infix, self._id)
10011008
self._launched = False
10021009
self._live_progress = None
10031010

@@ -1047,7 +1054,7 @@ def reset(self) -> "Experiment":
10471054
f"[bold magenta]Failed resetting Experiment {self._id} due to error: {e}"
10481055
)
10491056
# Double check exp dir is unchanged
1050-
new_path = os.path.join(get_nemorun_home(), "experiments", self._title, self._id)
1057+
new_path = os.path.join(get_nemorun_home(), self._exp_dir_infix, self._id)
10511058
if self._exp_dir == new_path and new_path != old_exp_dir:
10521059
shutil.rmtree(self._exp_dir)
10531060

0 commit comments

Comments
 (0)