1111import argparse
1212import ast
1313import inspect
14+ import sys
1415from dataclasses import dataclass
1516from typing import Callable , cast , Dict , List , Optional , Tuple
1617
@@ -98,7 +99,7 @@ class LinterMessage:
9899 severity : str = "error"
99100
100101
101- class TorchxFunctionValidator (abc .ABC ):
102+ class ComponentFunctionValidator (abc .ABC ):
102103 @abc .abstractmethod
103104 def validate (self , app_specs_func_def : ast .FunctionDef ) -> List [LinterMessage ]:
104105 """
@@ -116,7 +117,55 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
116117 )
117118
118119
119- class TorchxFunctionArgsValidator (TorchxFunctionValidator ):
120+ def OK () -> list [LinterMessage ]:
121+ return [] # empty linter error means validation passes
122+
123+
124+ def is_primitive (arg : ast .expr ) -> bool :
125+ # whether the arg is a primitive type (e.g. int, float, str, bool)
126+ return isinstance (arg , ast .Name )
127+
128+
129+ def get_generic_type (arg : ast .expr ) -> ast .expr :
130+ # returns the slice expr of a subscripted type
131+ # `arg` must be an instance of ast.Subscript (caller checks)
132+ # in this validator's context, this is the generic type of a container type
133+ # e.g. for Optional[str] returns the expr for str
134+
135+ assert isinstance (arg , ast .Subscript ) # e.g. arg = C[T]
136+
137+ if isinstance (arg .slice , ast .Index ): # python>=3.10
138+ return arg .slice .value
139+ else : # python-3.9
140+ return arg .slice
141+
142+
143+ def get_optional_type (arg : ast .expr ) -> Optional [ast .expr ]:
144+ """
145+ Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
146+ is not an ``Optional``. Handles both:
147+ 1. ``typing.Optional[T]`` (python<3.10)
148+ 2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
149+ """
150+ # case 1: 'a: Optional[T]'
151+ if isinstance (arg , ast .Subscript ) and arg .value .id == "Optional" :
152+ return get_generic_type (arg )
153+
154+ # case 2: 'a: T | None' or 'a: None | T'
155+ if sys .version_info >= (3 , 10 ): # PEP 604 introduced in python-3.10
156+ if isinstance (arg , ast .BinOp ) and isinstance (arg .op , ast .BitOr ):
157+ if isinstance (arg .right , ast .Constant ) and arg .right .value is None :
158+ return arg .left
159+ if isinstance (arg .left , ast .Constant ) and arg .left .value is None :
160+ return arg .right
161+
162+ # case 3: is not optional
163+ return None
164+
165+
166+ class ArgTypeValidator (ComponentFunctionValidator ):
167+ """Validates component function's argument types."""
168+
120169 def validate (self , app_specs_func_def : ast .FunctionDef ) -> List [LinterMessage ]:
121170 linter_errors = []
122171 for arg_def in app_specs_func_def .args .args :
@@ -133,53 +182,68 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
133182 return linter_errors
134183
135184 def _validate_arg_def (
136- self , function_name : str , arg_def : ast .arg
185+ self , function_name : str , arg : ast .arg
137186 ) -> List [LinterMessage ]:
138- if not arg_def .annotation :
139- return [
140- self ._gen_linter_message (
141- f"Arg { arg_def .arg } missing type annotation" , arg_def .lineno
142- )
143- ]
144- if isinstance (arg_def .annotation , ast .Name ):
187+ arg_type = arg .annotation # type hint
188+
189+ def ok () -> list [LinterMessage ]:
190+ # return value when validation passes (e.g. no linter errors)
145191 return []
146- complex_type_def = cast (ast .Subscript , none_throws (arg_def .annotation ))
147- if complex_type_def .value .id == "Optional" :
148- # ast module in python3.9 does not have ast.Index wrapper
149- if isinstance (complex_type_def .slice , ast .Index ):
150- complex_type_def = complex_type_def .slice .value
151- else :
152- complex_type_def = complex_type_def .slice
153- # Check if type is Optional[primitive_type]
154- if isinstance (complex_type_def , ast .Name ):
155- return []
156- # Check if type is Union[Dict,List]
157- type_name = complex_type_def .value .id
158- if type_name != "Dict" and type_name != "List" :
159- desc = (
160- f"`{ function_name } ` allows only Dict, List as complex types."
161- f"Argument `{ arg_def .arg } ` has: { type_name } "
162- )
163- return [self ._gen_linter_message (desc , arg_def .lineno )]
164- linter_errors = []
165- # ast module in python3.9 does not have objects wrapped in ast.Index
166- if isinstance (complex_type_def .slice , ast .Index ):
167- sub_type = complex_type_def .slice .value
192+
193+ def err (reason : str ) -> list [LinterMessage ]:
194+ msg = f"{ reason } for argument { ast .unparse (arg )!r} in function { function_name !r} "
195+ return [self ._gen_linter_message (msg , arg .lineno )]
196+
197+ if not arg_type :
198+ return err ("Missing type annotation" )
199+
200+ # Case 1: optional
201+ if T := get_optional_type (arg_type ):
202+ # NOTE: optional types can be primitives or any of the allowed container types
203+ # so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
204+ arg_type = T
205+
206+ # Case 2: int, float, str, bool
207+ if is_primitive (arg_type ):
208+ return ok ()
209+ # Case 3: Containers (Dict, List, Tuple)
210+ elif isinstance (arg_type , ast .Subscript ):
211+ container_type = arg_type .value .id
212+
213+ if container_type in ["Dict" , "dict" ]:
214+ KV = get_generic_type (arg_type )
215+
216+ assert isinstance (KV , ast .Tuple ) # dict[K,V] has ast.Tuple slice
217+
218+ K , V = KV .elts
219+ if not is_primitive (K ):
220+ return err (f"Non-primitive key type { ast .unparse (K )!r} " )
221+ if not is_primitive (V ):
222+ return err (f"Non-primitive value type { ast .unparse (V )!r} " )
223+ return ok ()
224+ elif container_type in ["List" , "list" ]:
225+ T = get_generic_type (arg_type )
226+ if is_primitive (T ):
227+ return ok ()
228+ else :
229+ return err (f"Non-primitive element type { ast .unparse (T )!r} " )
230+ elif container_type in ["Tuple" , "tuple" ]:
231+ E_N = get_generic_type (arg_type )
232+ assert isinstance (E_N , ast .Tuple ) # tuple[...] has ast.Tuple slice
233+
234+ for e in E_N .elts :
235+ if not is_primitive (e ):
236+ return err (f"Non-primitive element type '{ ast .unparse (e )!r} '" )
237+
238+ return ok ()
239+
240+ return err (f"Unsupported container type { container_type !r} " )
168241 else :
169- sub_type = complex_type_def .slice
170- if type_name == "Dict" :
171- sub_type_tuple = cast (ast .Tuple , sub_type )
172- for el in sub_type_tuple .elts :
173- if not isinstance (el , ast .Name ):
174- desc = "Dict can only have primitive types"
175- linter_errors .append (self ._gen_linter_message (desc , arg_def .lineno ))
176- elif not isinstance (sub_type , ast .Name ):
177- desc = "List can only have primitive types"
178- linter_errors .append (self ._gen_linter_message (desc , arg_def .lineno ))
179- return linter_errors
242+ return err (f"Unsupported argument type { ast .unparse (arg_type )!r} " )
180243
181244
182- class TorchxReturnValidator (TorchxFunctionValidator ):
245+ class ReturnTypeValidator (ComponentFunctionValidator ):
246+ """Validates that component functions always return AppDef type"""
183247
184248 def __init__ (self , supported_return_type : str ) -> None :
185249 super ().__init__ ()
@@ -231,7 +295,7 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
231295 return linter_errors
232296
233297
234- class TorchFunctionVisitor (ast .NodeVisitor ):
298+ class ComponentFnVisitor (ast .NodeVisitor ):
235299 """
236300 Visitor that finds the component_function and runs registered validators on it.
237301 Current registered validators:
@@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
252316 def __init__ (
253317 self ,
254318 component_function_name : str ,
255- validators : Optional [List [TorchxFunctionValidator ]],
319+ validators : Optional [List [ComponentFunctionValidator ]],
256320 ) -> None :
257321 if validators is None :
258- self .validators : List [TorchxFunctionValidator ] = [
259- TorchxFunctionArgsValidator (),
260- TorchxReturnValidator ("AppDef" ),
322+ self .validators : List [ComponentFunctionValidator ] = [
323+ ArgTypeValidator (),
324+ ReturnTypeValidator ("AppDef" ),
261325 ]
262326 else :
263327 self .validators = validators
@@ -279,7 +343,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
279343def validate (
280344 path : str ,
281345 component_function : str ,
282- validators : Optional [List [TorchxFunctionValidator ]] ,
346+ validators : Optional [List [ComponentFunctionValidator ]] = None ,
283347) -> List [LinterMessage ]:
284348 """
285349 Validates the function to make sure it complies the component standard.
@@ -309,7 +373,7 @@ def validate(
309373 severity = "error" ,
310374 )
311375 return [linter_message ]
312- visitor = TorchFunctionVisitor (component_function , validators )
376+ visitor = ComponentFnVisitor (component_function , validators )
313377 visitor .visit (module )
314378 linter_errors = visitor .linter_errors
315379 if not visitor .visited_function :
0 commit comments