Skip to content

Commit 9fc7481

Browse files
authored
Add option to return dict in Experiment.status (#91)
Signed-off-by: Hemil Desai <[email protected]>
1 parent c0446bf commit 9fc7481

File tree

1 file changed

+55
-32
lines changed

1 file changed

+55
-32
lines changed

src/nemo_run/run/experiment.py

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def set_context(context: contextvars.Context):
726726
finally:
727727
job.cleanup()
728728

729-
def status(self):
729+
def status(self, return_dict: bool = False) -> Optional[dict[str, str]]:
730730
"""
731731
Prints a table specifying the status of all tasks.
732732
@@ -740,41 +740,64 @@ def status(self):
740740
_current_experiment.set(self)
741741
_set_current_experiment = True
742742

743+
def _get_job_info_and_dict(
744+
idx: int, job: Job | JobGroup
745+
) -> tuple[list[str], dict[str, str]]:
746+
job_info = []
747+
job_info.append(f"[bold green]Task {idx}[/bold green]: [bold orange1]{job.id}")
748+
job_info.append(
749+
f"- [bold green]Status[/bold green]: {str(job.status(runner=self._runner))}"
750+
)
751+
job_info.append(f"- [bold green]Executor[/bold green]: {job.executor.info()}")
752+
753+
try:
754+
_, _, path_str = job.handle.partition("://")
755+
path = path_str.split("/")
756+
app_id = path[1]
757+
except Exception:
758+
app_id = ""
759+
760+
job_info.append(f"- [bold green]Job id[/bold green]: {app_id}")
761+
directory_info = [
762+
"- [bold green]Local Directory[/bold green]: " + job.executor.job_dir,
763+
]
764+
job_dict = {
765+
"name": job.id,
766+
"status": job.status(runner=self._runner),
767+
"executor": job.executor.info(),
768+
"job_id": app_id,
769+
"local_dir": job.executor.job_dir,
770+
}
771+
772+
if isinstance(job.executor, SlurmExecutor) and isinstance(
773+
job.executor.tunnel, SSHTunnel
774+
):
775+
directory_info.extend(
776+
[
777+
"- [bold green]Remote Directory[/bold green]: "
778+
+ os.path.join(
779+
job.executor.tunnel.job_dir,
780+
Path(job.executor.job_dir).name,
781+
),
782+
]
783+
)
784+
job_dict["remote_dir"] = os.path.join(
785+
job.executor.tunnel.job_dir,
786+
Path(job.executor.job_dir).name,
787+
)
788+
job_info.extend(directory_info)
789+
return job_info, job_dict
790+
743791
try:
792+
result_dict = {}
744793
job_infos = []
745794
for i, job in enumerate(self.jobs):
746-
job_info = []
747-
job_info.append(f"[bold green]Task {i}[/bold green]: [bold orange1]{job.id}")
748-
job_info.append(
749-
f"- [bold green]Status[/bold green]: {str(job.status(runner=self._runner))}"
750-
)
751-
job_info.append(f"- [bold green]Executor[/bold green]: {job.executor.info()}")
752-
753-
try:
754-
_, _, path_str = job.handle.partition("://")
755-
path = path_str.split("/")
756-
app_id = path[1]
757-
except Exception:
758-
app_id = ""
759-
760-
job_info.append(f"- [bold green]Job id[/bold green]: {app_id}")
761-
directory_info = [
762-
"- [bold green]Local Directory[/bold green]: " + job.executor.job_dir,
763-
]
764-
if isinstance(job.executor, SlurmExecutor) and isinstance(
765-
job.executor.tunnel, SSHTunnel
766-
):
767-
directory_info.extend(
768-
[
769-
"- [bold green]Remote Directory[/bold green]: "
770-
+ os.path.join(
771-
job.executor.tunnel.job_dir,
772-
Path(job.executor.job_dir).name,
773-
),
774-
]
775-
)
776-
job_info.extend(directory_info)
795+
job_info, job_dict = _get_job_info_and_dict(i, job)
777796
job_infos.append(Group(*job_info))
797+
result_dict[job.id] = job_dict
798+
799+
if return_dict:
800+
return result_dict
778801

779802
self.console.print()
780803
self.console.print(

0 commit comments

Comments
 (0)