66
77import argparse
88import logging
9- import sys
10- from typing import Callable , Optional
119
1210from tabulate import tabulate
1311
1412from torchx .cli .cmd_base import SubCommand
1513from torchx .runner .api import get_configured_trackers
1614from torchx .tracker .api import build_trackers , TrackerBase
17- from torchx .util .types import none_throws
1815
1916logger : logging .Logger = logging .getLogger (__name__ )
2017
2118
22- def _requires_tracker (
23- command : Callable [["CmdTracker" , argparse .Namespace ], None ]
24- ) -> Callable [["CmdTracker" , argparse .Namespace ], None ]:
25- """Checks that command has valid tracker setup"""
26-
27- def wrapper (self : "CmdTracker" , args : argparse .Namespace ) -> None :
28- if not self .tracker :
29- logger .error ("Exiting since no trackers were configured." )
30- sys .exit (1 )
31- command (self , args )
32-
33- return wrapper
34-
35-
3619class CmdTracker (SubCommand ):
3720 """
3821 Prototype TorchX tracker subcommand that allows querying data by
@@ -49,30 +32,28 @@ class CmdTracker(SubCommand):
4932 def __init__ (self ) -> None :
5033 """
5134 Queries available tracker implementations and uses the first available one.
52-
53- Since the instance needs to be available to setup torchx arguments, subcommands
54- utilize `_requires_tracker()` annotation to check that tracker is available
55- when invoked.
5635 """
57- self .tracker : Optional [TrackerBase ] = None
58- configured_trackers = get_configured_trackers ()
59- if configured_trackers :
60- trackers = build_trackers (configured_trackers )
61- if trackers :
62- self .tracker = next (iter (trackers ))
63- logger .info (f"Using { self .tracker } to query data" )
64- else :
65- logger .error ("No trackers were configured!" )
36+
37+ @property
38+ def tracker (self ) -> TrackerBase :
39+ trackers = list (build_trackers (get_configured_trackers ()))
40+ if trackers :
41+ logger .info (f"Using `{ trackers [0 ]} ` tracker to query data" )
42+ return trackers [0 ]
43+ else :
44+ raise RuntimeError (
45+ "No trackers configured."
46+ " See: https://pytorch.org/torchx/latest/runtime/tracking.html"
47+ )
6648
6749 def add_list_job_arguments (self , subparser : argparse .ArgumentParser ) -> None :
6850 subparser .add_argument (
6951 "--parent-run-id" , type = str , help = "Optional job parent run ID"
7052 )
7153
72- @_requires_tracker
7354 def list_jobs_command (self , args : argparse .Namespace ) -> None :
7455 parent_run_id = args .parent_run_id
75- job_ids = none_throws ( self .tracker ) .run_ids (parent_run_id = parent_run_id )
56+ job_ids = self .tracker .run_ids (parent_run_id = parent_run_id )
7657
7758 tabulated_job_ids = [[job_id ] for job_id in job_ids ]
7859 print (tabulate (tabulated_job_ids , headers = ["JOB ID" ]))
@@ -91,17 +72,15 @@ def add_job_lineage_arguments(self, subparser: argparse.ArgumentParser) -> None:
9172 )
9273 subparser .add_argument ("RUN_ID" , type = str , help = "Job run ID" )
9374
94- @_requires_tracker
9575 def job_lineage_command (self , args : argparse .Namespace ) -> None :
9676 raise NotImplementedError ("" )
9777
9878 def add_metadata_arguments (self , subparser : argparse .ArgumentParser ) -> None :
9979 subparser .add_argument ("RUN_ID" , type = str , help = "Job run ID" )
10080
101- @_requires_tracker
10281 def list_metadata_command (self , args : argparse .Namespace ) -> None :
10382 run_id = args .RUN_ID
104- metadata = none_throws ( self .tracker ) .metadata (run_id )
83+ metadata = self .tracker .metadata (run_id )
10584 print_data = [[k , v ] for k , v in metadata .items ()]
10685
10786 print (tabulate (print_data , headers = ["ID" , "VALUE" ]))
@@ -113,21 +92,16 @@ def add_artifacts_arguments(self, subparser: argparse.ArgumentParser) -> None:
11392
11493 subparser .add_argument ("RUN_ID" , type = str , help = "Job run ID" )
11594
116- @_requires_tracker
11795 def list_artifacts_command (self , args : argparse .Namespace ) -> None :
11896 run_id = args .RUN_ID
11997 artifact_filter = args .artifact
12098
121- artifacts = none_throws (self .tracker ).artifacts (run_id )
122- artifacts = artifacts .values ()
99+ artifacts = list (self .tracker .artifacts (run_id ).values ())
123100
124101 if artifact_filter :
125- artifacts = [
126- artifact for artifact in artifacts if artifact .name == artifact_filter
127- ]
128- print_data = [
129- [artifact .name , artifact .path , artifact .metadata ] for artifact in artifacts
130- ]
102+ artifacts = [a for a in artifacts if a .name == artifact_filter ]
103+
104+ print_data = [[a .name , a .path , a .metadata ] for a in artifacts ]
131105
132106 print (tabulate (print_data , headers = ["ARTIFACT" , "PATH" , "METADATA" ]))
133107
0 commit comments