diff --git a/docs/source/specs.rst b/docs/source/specs.rst index 33e40f08f..6fd1ea2c4 100644 --- a/docs/source/specs.rst +++ b/docs/source/specs.rst @@ -90,17 +90,17 @@ Component Linter .. autoclass:: LinterMessage :members: -.. autoclass:: TorchFunctionVisitor +.. autoclass:: ComponentFnVisitor :members: .. autoclass:: TorchXArgumentHelpFormatter :members: -.. autoclass:: TorchxFunctionArgsValidator +.. autoclass:: ArgTypeValidator :members: -.. autoclass:: TorchxFunctionValidator +.. autoclass:: ComponentFunctionValidator :members: -.. autoclass:: TorchxReturnValidator +.. autoclass:: ReturnTypeValidator :members: diff --git a/torchx/specs/file_linter.py b/torchx/specs/file_linter.py index b11abd78f..38c8935ba 100644 --- a/torchx/specs/file_linter.py +++ b/torchx/specs/file_linter.py @@ -11,8 +11,9 @@ import argparse import ast import inspect +import sys from dataclasses import dataclass -from typing import Callable, cast, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple from docstring_parser import parse from torchx.util.io import read_conf_file @@ -98,7 +99,7 @@ class LinterMessage: severity: str = "error" -class TorchxFunctionValidator(abc.ABC): +class ComponentFunctionValidator(abc.ABC): @abc.abstractmethod def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]: """ @@ -116,7 +117,55 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage: ) -class TorchxFunctionArgsValidator(TorchxFunctionValidator): +def OK() -> list[LinterMessage]: + return [] # empty linter error means validation passes + + +def is_primitive(arg: ast.expr) -> bool: + # whether the arg is a primitive type (e.g. 
int, float, str, bool) + return isinstance(arg, ast.Name) + + +def get_generic_type(arg: ast.expr) -> ast.expr: + # returns the slice expr of a subscripted type + # `arg` must be an instance of ast.Subscript (caller checks) + # in this validator's context, this is the generic type of a container type + # e.g. for Optional[str] returns the expr for str + + assert isinstance(arg, ast.Subscript) # e.g. arg = C[T] + + if isinstance(arg.slice, ast.Index): # python<3.9 (slice is wrapped in ast.Index) + return arg.slice.value + else: # python>=3.9 (slice is the expr itself) + return arg.slice + + +def get_optional_type(arg: ast.expr) -> Optional[ast.expr]: + """ + Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if ``arg`` + is not an ``Optional``. Handles both: + 1. ``typing.Optional[T]`` (all python versions) + 2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604) + """ + # case 1: 'a: Optional[T]' + if isinstance(arg, ast.Subscript) and arg.value.id == "Optional": + return get_generic_type(arg) + + # case 2: 'a: T | None' or 'a: None | T' + if sys.version_info >= (3, 10): # PEP 604 introduced in python-3.10 + if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr): + if isinstance(arg.right, ast.Constant) and arg.right.value is None: + return arg.left + if isinstance(arg.left, ast.Constant) and arg.left.value is None: + return arg.right + + # case 3: is not optional + return None + + +class ArgTypeValidator(ComponentFunctionValidator): + """Validates component function's argument types.""" + def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]: linter_errors = [] for arg_def in app_specs_func_def.args.args: @@ -133,53 +182,68 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]: return linter_errors def _validate_arg_def( - self, function_name: str, arg_def: ast.arg + self, function_name: str, arg: ast.arg ) -> List[LinterMessage]: - if not arg_def.annotation: - return [ - self._gen_linter_message( - f"Arg {arg_def.arg} missing type annotation",
arg_def.lineno - ) - ] - if isinstance(arg_def.annotation, ast.Name): + arg_type = arg.annotation # type hint + + def ok() -> list[LinterMessage]: + # return value when validation passes (e.g. no linter errors) return [] - complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation)) - if complex_type_def.value.id == "Optional": - # ast module in python3.9 does not have ast.Index wrapper - if isinstance(complex_type_def.slice, ast.Index): - complex_type_def = complex_type_def.slice.value - else: - complex_type_def = complex_type_def.slice - # Check if type is Optional[primitive_type] - if isinstance(complex_type_def, ast.Name): - return [] - # Check if type is Union[Dict,List] - type_name = complex_type_def.value.id - if type_name != "Dict" and type_name != "List": - desc = ( - f"`{function_name}` allows only Dict, List as complex types." - f"Argument `{arg_def.arg}` has: {type_name}" - ) - return [self._gen_linter_message(desc, arg_def.lineno)] - linter_errors = [] - # ast module in python3.9 does not have objects wrapped in ast.Index - if isinstance(complex_type_def.slice, ast.Index): - sub_type = complex_type_def.slice.value + + def err(reason: str) -> list[LinterMessage]: + msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}" + return [self._gen_linter_message(msg, arg.lineno)] + + if not arg_type: + return err("Missing type annotation") + + # Case 1: optional + if T := get_optional_type(arg_type): + # NOTE: optional types can be primitives or any of the allowed container types + # so check if arg is an optional, and if so, run the rest of the validation with the unpacked type + arg_type = T + + # Case 2: int, float, str, bool + if is_primitive(arg_type): + return ok() + # Case 3: Containers (Dict, List, Tuple) + elif isinstance(arg_type, ast.Subscript): + container_type = arg_type.value.id + + if container_type in ["Dict", "dict"]: + KV = get_generic_type(arg_type) + + assert isinstance(KV, ast.Tuple) # dict[K,V] has 
ast.Tuple slice + + K, V = KV.elts + if not is_primitive(K): + return err(f"Non-primitive key type {ast.unparse(K)!r}") + if not is_primitive(V): + return err(f"Non-primitive value type {ast.unparse(V)!r}") + return ok() + elif container_type in ["List", "list"]: + T = get_generic_type(arg_type) + if is_primitive(T): + return ok() + else: + return err(f"Non-primitive element type {ast.unparse(T)!r}") + elif container_type in ["Tuple", "tuple"]: + E_N = get_generic_type(arg_type) + assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice + + for e in E_N.elts: + if not is_primitive(e): + return err(f"Non-primitive element type '{ast.unparse(e)!r}'") + + return ok() + + return err(f"Unsupported container type {container_type!r}") else: - sub_type = complex_type_def.slice - if type_name == "Dict": - sub_type_tuple = cast(ast.Tuple, sub_type) - for el in sub_type_tuple.elts: - if not isinstance(el, ast.Name): - desc = "Dict can only have primitive types" - linter_errors.append(self._gen_linter_message(desc, arg_def.lineno)) - elif not isinstance(sub_type, ast.Name): - desc = "List can only have primitive types" - linter_errors.append(self._gen_linter_message(desc, arg_def.lineno)) - return linter_errors + return err(f"Unsupported argument type {ast.unparse(arg_type)!r}") -class TorchxReturnValidator(TorchxFunctionValidator): +class ReturnTypeValidator(ComponentFunctionValidator): + """Validates that component functions always return AppDef type""" def __init__(self, supported_return_type: str) -> None: super().__init__() @@ -231,7 +295,7 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]: return linter_errors -class TorchFunctionVisitor(ast.NodeVisitor): +class ComponentFnVisitor(ast.NodeVisitor): """ Visitor that finds the component_function and runs registered validators on it. 
Current registered validators: @@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor): def __init__( self, component_function_name: str, - validators: Optional[List[TorchxFunctionValidator]], + validators: Optional[List[ComponentFunctionValidator]], ) -> None: if validators is None: - self.validators: List[TorchxFunctionValidator] = [ - TorchxFunctionArgsValidator(), - TorchxReturnValidator("AppDef"), + self.validators: List[ComponentFunctionValidator] = [ + ArgTypeValidator(), + ReturnTypeValidator("AppDef"), ] else: self.validators = validators @@ -279,7 +343,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: def validate( path: str, component_function: str, - validators: Optional[List[TorchxFunctionValidator]], + validators: Optional[List[ComponentFunctionValidator]] = None, ) -> List[LinterMessage]: """ Validates the function to make sure it complies the component standard. @@ -309,7 +373,7 @@ def validate( severity="error", ) return [linter_message] - visitor = TorchFunctionVisitor(component_function, validators) + visitor = ComponentFnVisitor(component_function, validators) visitor.visit(module) linter_errors = visitor.linter_errors if not visitor.visited_function: diff --git a/torchx/specs/finder.py b/torchx/specs/finder.py index 66900e7a4..def4699b4 100644 --- a/torchx/specs/finder.py +++ b/torchx/specs/finder.py @@ -19,7 +19,11 @@ from types import ModuleType from typing import Any, Callable, Dict, Generator, List, Optional, Union -from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate +from torchx.specs.file_linter import ( + ComponentFunctionValidator, + get_fn_docstring, + validate, +) from torchx.util import entrypoints from torchx.util.io import read_conf_file from torchx.util.types import none_throws @@ -64,7 +68,7 @@ class _Component: class ComponentsFinder(abc.ABC): @abc.abstractmethod def find( - self, validators: Optional[List[TorchxFunctionValidator]] + self, validators: 
Optional[List[ComponentFunctionValidator]] ) -> List[_Component]: """ Retrieves a set of components. A component is defined as a python @@ -210,7 +214,7 @@ def _iter_modules_recursive( yield self._try_import(module_info.name) def find( - self, validators: Optional[List[TorchxFunctionValidator]] + self, validators: Optional[List[ComponentFunctionValidator]] ) -> List[_Component]: components = [] for m in self._iter_modules_recursive(self.base_module): @@ -230,7 +234,7 @@ def _try_import(self, module: Union[str, ModuleType]) -> ModuleType: return module def _get_components_from_module( - self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]] + self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]] ) -> List[_Component]: functions = getmembers(module, isfunction) component_defs = [] @@ -269,7 +273,7 @@ def _get_validation_errors( self, path: str, function_name: str, - validators: Optional[List[TorchxFunctionValidator]], + validators: Optional[List[ComponentFunctionValidator]], ) -> List[str]: linter_errors = validate(path, function_name, validators) return [linter_error.description for linter_error in linter_errors] @@ -289,7 +293,7 @@ def _get_path_to_function_decl( return path_to_function_decl def find( - self, validators: Optional[List[TorchxFunctionValidator]] + self, validators: Optional[List[ComponentFunctionValidator]] ) -> List[_Component]: file_source = read_conf_file(self._filepath) @@ -321,7 +325,7 @@ def find( def _load_custom_components( - validators: Optional[List[TorchxFunctionValidator]], + validators: Optional[List[ComponentFunctionValidator]], ) -> List[_Component]: component_modules = { name: load_fn() @@ -346,7 +350,7 @@ def _load_custom_components( def _load_components( - validators: Optional[List[TorchxFunctionValidator]], + validators: Optional[List[ComponentFunctionValidator]], ) -> Dict[str, _Component]: """ Loads either the custom component defs from the entrypoint ``[torchx.components]`` @@ 
-368,7 +372,7 @@ def _load_components( def _find_components( - validators: Optional[List[TorchxFunctionValidator]], + validators: Optional[List[ComponentFunctionValidator]], ) -> Dict[str, _Component]: global _components if not _components: @@ -381,7 +385,7 @@ def _is_custom_component(component_name: str) -> bool: def _find_custom_components( - name: str, validators: Optional[List[TorchxFunctionValidator]] + name: str, validators: Optional[List[ComponentFunctionValidator]] ) -> Dict[str, _Component]: if ":" not in name: raise ValueError( @@ -393,7 +397,7 @@ def _find_custom_components( def get_components( - validators: Optional[List[TorchxFunctionValidator]] = None, + validators: Optional[List[ComponentFunctionValidator]] = None, ) -> Dict[str, _Component]: """ Returns all custom components registered via ``[torchx.components]`` entrypoints @@ -448,7 +452,7 @@ def get_components( def get_component( - name: str, validators: Optional[List[TorchxFunctionValidator]] = None + name: str, validators: Optional[List[ComponentFunctionValidator]] = None ) -> _Component: """ Retrieves components by the provided name. @@ -477,7 +481,7 @@ def get_component( def get_builtin_source( - name: str, validators: Optional[List[TorchxFunctionValidator]] = None + name: str, validators: Optional[List[ComponentFunctionValidator]] = None ) -> str: """ Returns a string of the the builtin component's function source code diff --git a/torchx/specs/test/file_linter_test.py b/torchx/specs/test/file_linter_test.py index 5428ce95c..b23363d92 100644 --- a/torchx/specs/test/file_linter_test.py +++ b/torchx/specs/test/file_linter_test.py @@ -5,24 +5,31 @@ # LICENSE file in the root directory of this source tree. 
# pyre-strict +# pyre-ignore-all-errors[2] +# arguments untyped in certain functions for testing +# flake8: noqa: B950 import argparse import os +import sys import unittest from typing import Dict, List, Optional from unittest.mock import patch +from torchx.specs import AppDef + from torchx.specs.file_linter import ( get_fn_docstring, TorchXArgumentHelpFormatter, validate, ) +IGNORED = AppDef(name="__IGNORED__") + # Note if the function is moved, the tests need to be updated with new lineno -# pyre-ignore[11]: Ignore unknown type "AppDef" -def _test_empty_fn() -> "AppDef": - pass +def _test_empty_fn() -> AppDef: + return IGNORED # Note if the function is moved, the tests need to be updated with new lineno @@ -40,7 +47,7 @@ def _test_fn_return_int() -> int: return 0 -def _test_docstring(arg0: str, arg1: int, arg2: Dict[int, str]) -> "AppDef": +def _test_docstring(arg0: str, arg1: int, arg2: Dict[int, str]) -> AppDef: """Short Test description Long funct description @@ -49,20 +56,19 @@ def _test_docstring(arg0: str, arg1: int, arg2: Dict[int, str]) -> "AppDef": arg0: arg0 desc arg1: arg1 desc """ - pass + return IGNORED -def _test_docstring_short() -> "AppDef": +def _test_docstring_short() -> AppDef: """Short Test description""" - pass + return IGNORED -def _test_without_docstring(arg0: str) -> "AppDef": - pass +def _test_without_docstring(arg0: str) -> AppDef: + return IGNORED -# pyre-ignore[2]: Omit return value for testing purposes -def _test_args_no_type_defs(arg0, arg1, arg2: Dict[int, str]) -> "AppDef": +def _test_args_no_type_defs(arg0, arg1, arg2: Dict[int, str]) -> AppDef: """ Test description @@ -71,18 +77,19 @@ def _test_args_no_type_defs(arg0, arg1, arg2: Dict[int, str]) -> "AppDef": arg1: arg1 desc arg2: arg2 desc """ - pass + return IGNORED -def _test_args_dict_list_complex_types( - # pyre-ignore[2]: Omit return value for testing purposes +def _test_args_complex_types( arg0, - # pyre-ignore[2]: Omit return value for testing purposes - arg1, - arg2: 
Dict[int, List[str]], - arg3: List[List[str]], - arg4: Optional[Optional[str]], -) -> "AppDef": + arg1: Dict[int, List[str]], + arg2: Dict[int, Dict[int, str]], + arg3: Dict[List[int], str], + arg4: Dict[Dict[int, str], str], + arg5: List[List[str]], + arg6: List[Dict[str, str]], + arg7: Optional[Optional[str]], +) -> AppDef: """ Test description @@ -92,18 +99,56 @@ def _test_args_dict_list_complex_types( arg2: arg2 desc arg3: arg2 desc """ - pass + return IGNORED -# pyre-ignore[2] -def _test_invalid_fn_with_varags_and_kwargs(*args, id: int) -> "AppDef": +def _test_args_builtin_complex_types( + arg0, + arg1: dict[int, list[str]], + arg2: dict[int, dict[int, str]], + arg3: dict[list[int], str], + arg4: dict[dict[int, str], str], + arg5: list[list[str]], + arg6: list[dict[str, str]], + arg7: Optional[Optional[str]], +) -> AppDef: + """ + Test description + + Args: + arg0: arg0 desc + arg1: arg1 desc + arg2: arg2 desc + arg3: arg2 desc + """ + return IGNORED + + +if sys.version_info >= (3, 10): + + def _test_args_optional_types( + arg0: int | None, + arg1: None | int, + arg2: dict[str, str] | None, + arg3: list[str] | None, + arg4: tuple[str, str] | None, + arg5: Optional[int], + arg6: Optional[dict[str, str]], + ) -> AppDef: + """ + Test both ways to specify optional for python-3.10+ + """ + return IGNORED + + +def _test_invalid_fn_with_varags_and_kwargs(*args, id: int) -> AppDef: """ Test description Args: args: args desc """ - pass + return IGNORED def current_file_path() -> str: @@ -121,21 +166,19 @@ def test_syntax_error(self) -> None: content = "!!foo====bar" with patch("torchx.specs.file_linter.read_conf_file") as read_conf_file_mock: read_conf_file_mock.return_value = content - errors = validate(self._path, "unknown_function", None) + errors = validate(self._path, "unknown_function") self.assertEqual(1, len(errors)) self.assertEqual("invalid syntax", errors[0].description) def test_validate_varargs_kwargs_fn(self) -> None: - linter_errors = validate( - 
self._path, "_test_invalid_fn_with_varags_and_kwargs", None - ) + linter_errors = validate(self._path, "_test_invalid_fn_with_varags_and_kwargs") self.assertEqual(1, len(linter_errors)) self.assertTrue( "Arg args missing type annotation", linter_errors[0].description ) def test_validate_no_return(self) -> None: - linter_errors = validate(self._path, "_test_fn_no_return", None) + linter_errors = validate(self._path, "_test_fn_no_return") self.assertEqual(1, len(linter_errors)) expected_desc = ( "Function: _test_fn_no_return missing return annotation or " @@ -144,7 +187,7 @@ def test_validate_no_return(self) -> None: self.assertEqual(expected_desc, linter_errors[0].description) def test_validate_incorrect_return(self) -> None: - linter_errors = validate(self._path, "_test_fn_return_int", None) + linter_errors = validate(self._path, "_test_fn_return_int") self.assertEqual(1, len(linter_errors)) expected_desc = ( "Function: _test_fn_return_int has incorrect return annotation, " @@ -165,40 +208,65 @@ def test_no_validators_has_no_validation(self) -> None: self.assertEqual(0, len(linter_errors)) def test_validate_empty_fn(self) -> None: - linter_errors = validate(self._path, "_test_empty_fn", None) + linter_errors = validate(self._path, "_test_empty_fn") self.assertEqual(0, len(linter_errors)) def test_validate_args_no_type_defs(self) -> None: - linter_errors = validate(self._path, "_test_args_no_type_defs", None) - print(linter_errors) - self.assertEqual(2, len(linter_errors)) - self.assertEqual( - "Arg arg0 missing type annotation", linter_errors[0].description - ) - self.assertEqual( - "Arg arg1 missing type annotation", linter_errors[1].description + fn = "_test_args_no_type_defs" + linter_errors = validate(self._path, fn) + error_msgs = [e.description for e in linter_errors] + + self.assertListEqual( + [ + "Missing type annotation for argument 'arg0' in function '_test_args_no_type_defs'", + "Missing type annotation for argument 'arg1' in function 
'_test_args_no_type_defs'", + ], + error_msgs, ) - def test_validate_args_no_type_defs_complex(self) -> None: - linter_errors = validate(self._path, "_test_args_dict_list_complex_types", None) - self.assertEqual(5, len(linter_errors)) - self.assertEqual( - "Arg arg0 missing type annotation", linter_errors[0].description + def test_validate_args_complex_types(self) -> None: + linter_errors = validate(self._path, "_test_args_complex_types") + error_msgs = [e.description for e in linter_errors] + self.assertListEqual( + [ + "Missing type annotation for argument 'arg0' in function '_test_args_complex_types'", + "Non-primitive value type 'List[str]' for argument 'arg1: Dict[int, List[str]]' in function '_test_args_complex_types'", + "Non-primitive value type 'Dict[int, str]' for argument 'arg2: Dict[int, Dict[int, str]]' in function '_test_args_complex_types'", + "Non-primitive key type 'List[int]' for argument 'arg3: Dict[List[int], str]' in function '_test_args_complex_types'", + "Non-primitive key type 'Dict[int, str]' for argument 'arg4: Dict[Dict[int, str], str]' in function '_test_args_complex_types'", + "Non-primitive element type 'List[str]' for argument 'arg5: List[List[str]]' in function '_test_args_complex_types'", + "Non-primitive element type 'Dict[str, str]' for argument 'arg6: List[Dict[str, str]]' in function '_test_args_complex_types'", + "Unsupported container type 'Optional' for argument 'arg7: Optional[Optional[str]]' in function '_test_args_complex_types'", + ], + error_msgs, ) - self.assertEqual( - "Arg arg1 missing type annotation", linter_errors[1].description - ) - self.assertEqual( - "Dict can only have primitive types", linter_errors[2].description - ) - self.assertEqual( - "List can only have primitive types", linter_errors[3].description - ) - self.assertEqual( - "`_test_args_dict_list_complex_types` allows only Dict, List as complex types.Argument `arg4` has: Optional", - linter_errors[4].description, + + def 
test_validate_args_builtin_complex_types(self) -> None: + linter_errors = validate(self._path, "_test_args_builtin_complex_types") + error_msgs = [e.description for e in linter_errors] + self.assertListEqual( + [ + "Missing type annotation for argument 'arg0' in function '_test_args_builtin_complex_types'", + "Non-primitive value type 'list[str]' for argument 'arg1: dict[int, list[str]]' in function '_test_args_builtin_complex_types'", + "Non-primitive value type 'dict[int, str]' for argument 'arg2: dict[int, dict[int, str]]' in function '_test_args_builtin_complex_types'", + "Non-primitive key type 'list[int]' for argument 'arg3: dict[list[int], str]' in function '_test_args_builtin_complex_types'", + "Non-primitive key type 'dict[int, str]' for argument 'arg4: dict[dict[int, str], str]' in function '_test_args_builtin_complex_types'", + "Non-primitive element type 'list[str]' for argument 'arg5: list[list[str]]' in function '_test_args_builtin_complex_types'", + "Non-primitive element type 'dict[str, str]' for argument 'arg6: list[dict[str, str]]' in function '_test_args_builtin_complex_types'", + "Unsupported container type 'Optional' for argument 'arg7: Optional[Optional[str]]' in function '_test_args_builtin_complex_types'", + ], + error_msgs, ) + # pyre-ignore[56] + @unittest.skipUnless( + sys.version_info >= (3, 10), + "typing optional as [type]|None requires python-3.10+", + ) + def test_validate_args_optional_type(self) -> None: + linter_errors = validate(self._path, "_test_args_optional_types") + self.assertFalse(linter_errors) + def test_validate_docstring(self) -> None: func_desc, param_desc = get_fn_docstring(_test_docstring) self.assertEqual("Short Test description\nLong funct description", func_desc) @@ -218,7 +286,7 @@ def test_validate_docstring_no_docs(self) -> None: self.assertEqual(" ", param_desc["arg0"]) def test_validate_unknown_function(self) -> None: - linter_errors = validate(self._path, "unknown_function", None) + linter_errors = 
validate(self._path, "unknown_function") self.assertEqual(1, len(linter_errors)) self.assertEqual( "Function unknown_function not found", linter_errors[0].description