1111import argparse
1212import ast
1313import inspect
14+ import sys
1415from dataclasses import dataclass
1516from typing import Callable, cast, Dict, List, Optional, Tuple
1617
@@ -98,7 +99,7 @@ class LinterMessage:
9899 severity: str = "error"
99100
100101
101- class TorchxFunctionValidator (abc.ABC):
102+ class ComponentFunctionValidator (abc.ABC):
102103 @abc.abstractmethod
103104 def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
104105 """
@@ -116,7 +117,55 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
116117 )
117118
118119
119- class TorchxFunctionArgsValidator(TorchxFunctionValidator):
120+ def OK() -> list[LinterMessage]:
121+ return [] # empty linter error means validation passes
122+
123+
124+ def is_primitive(arg: ast.expr) -> bool:
125+ # whether the arg is a primitive type (e.g. int, float, str, bool)
126+ return isinstance(arg, ast.Name)
127+
128+
129+ def get_generic_type(arg: ast.expr) -> ast.expr:
130+ # returns the slice expr of a subscripted type
131+ # `arg` must be an instance of ast.Subscript (caller checks)
132+ # in this validator's context, this is the generic type of a container type
133+ # e.g. for Optional[str] returns the expr for str
134+
135+ assert isinstance(arg, ast.Subscript) # e.g. arg = C[T]
136+
137+ if isinstance(arg.slice, ast.Index): # python>=3.10
138+ return arg.slice.value
139+ else: # python-3.9
140+ return arg.slice
141+
142+
143+ def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
144+ """
145+ Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if `arg``
146+ is not an ``Optional``. Handles both:
147+ 1. ``typing.Optional[T]`` (python<3.10)
148+ 2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)
149+ """
150+ # case 1: 'a: Optional[T]'
151+ if isinstance(arg, ast.Subscript) and arg.value.id == "Optional":
152+ return get_generic_type(arg)
153+
154+ # case 2: 'a: T | None' or 'a: None | T'
155+ if sys.version_info >= (3, 10): # PEP 604 introduced in python-3.10
156+ if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
157+ if isinstance(arg.right, ast.Constant) and arg.right.value is None:
158+ return arg.left
159+ if isinstance(arg.left, ast.Constant) and arg.left.value is None:
160+ return arg.right
161+
162+ # case 3: is not optional
163+ return None
164+
165+
166+ class ArgTypeValidator(ComponentFunctionValidator):
167+ """Validates component function's argument types."""
168+
120169 def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
121170 linter_errors = []
122171 for arg_def in app_specs_func_def.args.args:
@@ -133,53 +182,68 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
133182 return linter_errors
134183
135184 def _validate_arg_def(
136- self, function_name: str, arg_def : ast.arg
185+ self, function_name: str, arg : ast.arg
137186 ) -> List[LinterMessage]:
138- if not arg_def.annotation:
139- return [
140- self._gen_linter_message(
141- f"Arg {arg_def.arg} missing type annotation", arg_def.lineno
142- )
143- ]
144- if isinstance(arg_def.annotation, ast.Name):
187+ arg_type = arg.annotation # type hint
188+
189+ def ok() -> list[LinterMessage]:
190+ # return value when validation passes (e.g. no linter errors)
145191 return []
146- complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
147- if complex_type_def.value.id == "Optional":
148- # ast module in python3.9 does not have ast.Index wrapper
149- if isinstance(complex_type_def.slice, ast.Index):
150- complex_type_def = complex_type_def.slice.value
151- else:
152- complex_type_def = complex_type_def.slice
153- # Check if type is Optional[primitive_type]
154- if isinstance(complex_type_def, ast.Name):
155- return []
156- # Check if type is Union[Dict,List]
157- type_name = complex_type_def.value.id
158- if type_name != "Dict" and type_name != "List":
159- desc = (
160- f"`{function_name}` allows only Dict, List as complex types."
161- f"Argument `{arg_def.arg}` has: {type_name}"
162- )
163- return [self._gen_linter_message(desc, arg_def.lineno)]
164- linter_errors = []
165- # ast module in python3.9 does not have objects wrapped in ast.Index
166- if isinstance(complex_type_def.slice, ast.Index):
167- sub_type = complex_type_def.slice.value
192+
193+ def err(reason: str) -> list[LinterMessage]:
194+ msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
195+ return [self._gen_linter_message(msg, arg.lineno)]
196+
197+ if not arg_type:
198+ return err("Missing type annotation")
199+
200+ # Case 1: optional
201+ if T := get_optional_type(arg_type):
202+ # NOTE: optional types can be primitives or any of the allowed container types
203+ # so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
204+ arg_type = T
205+
206+ # Case 2: int, float, str, bool
207+ if is_primitive(arg_type):
208+ return ok()
209+ # Case 3: Containers (Dict, List, Tuple)
210+ elif isinstance(arg_type, ast.Subscript):
211+ container_type = arg_type.value.id
212+
213+ if container_type in ["Dict", "dict"]:
214+ KV = get_generic_type(arg_type)
215+
216+ assert isinstance(KV, ast.Tuple) # dict[K,V] has ast.Tuple slice
217+
218+ K, V = KV.elts
219+ if not is_primitive(K):
220+ return err(f"Non-primitive key type {ast.unparse(K)!r}")
221+ if not is_primitive(V):
222+ return err(f"Non-primitive value type {ast.unparse(V)!r}")
223+ return ok()
224+ elif container_type in ["List", "list"]:
225+ T = get_generic_type(arg_type)
226+ if is_primitive(T):
227+ return ok()
228+ else:
229+ return err(f"Non-primitive element type {ast.unparse(T)!r}")
230+ elif container_type in ["Tuple", "tuple"]:
231+ E_N = get_generic_type(arg_type)
232+ assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice
233+
234+ for e in E_N.elts:
235+ if not is_primitive(e):
236+ return err(f"Non-primitive element type '{ast.unparse(e)!r}'")
237+
238+ return ok()
239+
240+ return err(f"Unsupported container type {container_type!r}")
168241 else:
169- sub_type = complex_type_def.slice
170- if type_name == "Dict":
171- sub_type_tuple = cast(ast.Tuple, sub_type)
172- for el in sub_type_tuple.elts:
173- if not isinstance(el, ast.Name):
174- desc = "Dict can only have primitive types"
175- linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
176- elif not isinstance(sub_type, ast.Name):
177- desc = "List can only have primitive types"
178- linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
179- return linter_errors
242+ return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")
180243
181244
182- class TorchxReturnValidator(TorchxFunctionValidator):
245+ class ReturnTypeValidator(ComponentFunctionValidator):
246+ """Validates that component functions always return AppDef type"""
183247
184248 def __init__(self, supported_return_type: str) -> None:
185249 super().__init__()
@@ -231,7 +295,7 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
231295 return linter_errors
232296
233297
234- class TorchFunctionVisitor (ast.NodeVisitor):
298+ class ComponentFnVisitor (ast.NodeVisitor):
235299 """
236300 Visitor that finds the component_function and runs registered validators on it.
237301 Current registered validators:
@@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
252316 def __init__(
253317 self,
254318 component_function_name: str,
255- validators: Optional[List[TorchxFunctionValidator ]],
319+ validators: Optional[List[ComponentFunctionValidator ]],
256320 ) -> None:
257321 if validators is None:
258- self.validators: List[TorchxFunctionValidator ] = [
259- TorchxFunctionArgsValidator (),
260- TorchxReturnValidator ("AppDef"),
322+ self.validators: List[ComponentFunctionValidator ] = [
323+ ArgTypeValidator (),
324+ ReturnTypeValidator ("AppDef"),
261325 ]
262326 else:
263327 self.validators = validators
@@ -279,7 +343,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
279343def validate(
280344 path: str,
281345 component_function: str,
282- validators: Optional[List[TorchxFunctionValidator]] ,
346+ validators: Optional[List[ComponentFunctionValidator]] = None ,
283347) -> List[LinterMessage]:
284348 """
285349 Validates the function to make sure it complies the component standard.
@@ -309,7 +373,7 @@ def validate(
309373 severity="error",
310374 )
311375 return [linter_message]
312- visitor = TorchFunctionVisitor (component_function, validators)
376+ visitor = ComponentFnVisitor (component_function, validators)
313377 visitor.visit(module)
314378 linter_errors = visitor.linter_errors
315379 if not visitor.visited_function:
0 commit comments