55# LICENSE file in the root directory of this source tree.
66
77import argparse
8- import glob
9- import importlib
10- import os
11- from dataclasses import dataclass
12- from inspect import getmembers , isfunction
13- from typing import Dict , List , Optional , Union
8+ from dataclasses import asdict
9+ from pprint import pformat
10+ from typing import Dict , List , Union , cast
1411
1512import torchx .specs as specs
1613from pyre_extensions import none_throws
1714from torchx .cli .cmd_base import SubCommand
1815from torchx .runner import get_runner
19- from torchx .specs .file_linter import get_fn_docstring , validate
20- from torchx .util import entrypoints
21- from torchx .util .io import COMPONENTS_DIR , get_abspath , read_conf_file
16+ from torchx .specs .finder import get_components , _Component
2217from torchx .util .types import to_dict
2318
2419
@@ -38,96 +33,19 @@ def _parse_run_config(arg: str) -> specs.RunConfig:
3833 return conf
3934
4035
41- def _to_module (filepath : str ) -> str :
42- path , _ = os .path .splitext (filepath )
43- return path .replace (os .path .sep , "." )
44-
45-
46- def _get_builtin_description (filepath : str , function_name : str ) -> Optional [str ]:
47- source = read_conf_file (filepath )
48- if len (validate (source , torchx_function = function_name )) != 0 :
49- return None
50-
51- func_definition , _ = none_throws (get_fn_docstring (source , function_name ))
52- return func_definition
53-
54-
55- @dataclass
56- class BuiltinComponent :
57- definition : str
58- description : str
59-
60-
61- def _get_component_definition (module : str , function_name : str ) -> str :
62- if module .startswith ("torchx.components" ):
63- module = module .split ("torchx.components." )[1 ]
64- return f"{ module } .{ function_name } "
65-
66-
67- def _to_relative (filepath : str ) -> str :
68- if os .path .isabs (filepath ):
69- # make path torchx/components/$suffix out of the abs
70- rel_path = filepath .split (str (COMPONENTS_DIR ))[1 ]
71- return f"{ str (COMPONENTS_DIR )} { rel_path } "
72- else :
73- return os .path .join (str (COMPONENTS_DIR ), filepath )
74-
75-
76- def _get_components_from_file (filepath : str ) -> List [BuiltinComponent ]:
77- components_path = _to_relative (filepath )
78- components_module_path = _to_module (components_path )
79- module = importlib .import_module (components_module_path )
80- functions = getmembers (module , isfunction )
81- buitin_functions = []
82- for function_name , _ in functions :
83- # Ignore private functions.
84- if function_name .startswith ("_" ):
85- continue
86- component_desc = _get_builtin_description (filepath , function_name )
87- if component_desc :
88- definition = _get_component_definition (
89- components_module_path , function_name
90- )
91- builtin_component = BuiltinComponent (
92- definition = definition ,
93- description = component_desc ,
94- )
95- buitin_functions .append (builtin_component )
96- return buitin_functions
97-
98-
99- def _allowed_path (path : str ) -> bool :
100- filename = os .path .basename (path )
101- if filename .startswith ("_" ):
102- return False
103- return True
104-
105-
106- def _builtins () -> List [BuiltinComponent ]:
107- components_dir = entrypoints .load (
108- "torchx.file" , "get_dir_path" , default = get_abspath
109- )(COMPONENTS_DIR )
110-
111- builtins : List [BuiltinComponent ] = []
112- search_pattern = os .path .join (components_dir , "**" , "*.py" )
113- for filepath in glob .glob (search_pattern , recursive = True ):
114- if not _allowed_path (filepath ):
115- continue
116- components = _get_components_from_file (filepath )
117- builtins += components
118- return builtins
119-
120-
12136class CmdBuiltins (SubCommand ):
12237 def add_arguments (self , subparser : argparse .ArgumentParser ) -> None :
123- pass # no arguments
38+ pass
39+
40+ def _builtins (self ) -> Dict [str , _Component ]:
41+ return get_components ()
12442
12543 def run (self , args : argparse .Namespace ) -> None :
126- builtin_configs = _builtins ()
127- num_builtins = len (builtin_configs )
44+ builtin_components = self . _builtins ()
45+ num_builtins = len (builtin_components )
12846 print (f"Found { num_builtins } builtin configs:" )
129- for i , component in enumerate (builtin_configs ):
130- print (f" { i + 1 :2d} . { component .definition } - { component . description } " )
47+ for i , component in enumerate (builtin_components . values () ):
48+ print (f" { i + 1 :2d} . { component .name } " )
13149
13250
13351class CmdRun (SubCommand ):
@@ -172,15 +90,23 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
17290 def run (self , args : argparse .Namespace ) -> None :
17391 # TODO: T91790598 - remove the if condition when all apps are migrated to pure python
17492 runner = get_runner ()
175- app_handle = runner .run_from_path (
93+ result = runner .run_component (
17694 args .conf_file ,
17795 args .conf_args ,
17896 args .scheduler ,
17997 args .scheduler_args ,
18098 dryrun = args .dryrun ,
18199 )
182100
183- if not args .dryrun :
101+ if args .dryrun :
102+ app_dryrun_info = cast (specs .AppDryRunInfo , result )
103+ print ("=== APPLICATION ===" )
104+ print (pformat (asdict (app_dryrun_info ._app ), indent = 2 , width = 80 ))
105+
106+ print ("=== SCHEDULER REQUEST ===" )
107+ print (app_dryrun_info )
108+ else :
109+ app_handle = cast (specs .AppHandle , result )
184110 if args .scheduler == "local" :
185111 runner .wait (app_handle )
186112 else :
0 commit comments