Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/source/specs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,17 @@ Component Linter
.. autoclass:: LinterMessage
:members:

.. autoclass:: TorchFunctionVisitor
.. autoclass:: ComponentFnVisitor
:members:

.. autoclass:: TorchXArgumentHelpFormatter
:members:

.. autoclass:: TorchxFunctionArgsValidator
.. autoclass:: ArgTypeValidator
:members:

.. autoclass:: TorchxFunctionValidator
.. autoclass:: ComponentFunctionValidator
:members:

.. autoclass:: TorchxReturnValidator
.. autoclass:: ReturnTypeValidator
:members:
168 changes: 116 additions & 52 deletions torchx/specs/file_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import argparse
import ast
import inspect
import sys
from dataclasses import dataclass
from typing import Callable, cast, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

from docstring_parser import parse
from torchx.util.io import read_conf_file
Expand Down Expand Up @@ -98,7 +99,7 @@ class LinterMessage:
severity: str = "error"


class TorchxFunctionValidator(abc.ABC):
class ComponentFunctionValidator(abc.ABC):
@abc.abstractmethod
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
"""
Expand All @@ -116,7 +117,55 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
)


class TorchxFunctionArgsValidator(TorchxFunctionValidator):
def OK() -> list[LinterMessage]:
    """Return the value that signals a passing validation.

    A validator reports its findings as a list of linter messages, so an
    empty list means no errors were found.
    """
    return list()


def is_primitive(arg: ast.expr) -> bool:
    """Return whether ``arg`` is treated as a primitive type annotation.

    A primitive type (e.g. ``int``, ``float``, ``str``, ``bool``) appears in
    the AST as a bare name, so the check is purely structural: any
    ``ast.Name`` node qualifies and the name itself is not inspected.
    """
    is_bare_name = isinstance(arg, ast.Name)
    return is_bare_name


def get_generic_type(arg: ast.expr) -> ast.expr:
    """Return the slice expression of a subscripted type annotation.

    In this validator's context that is the type parameter of a container
    type, e.g. for ``Optional[str]`` the expression for ``str`` and for
    ``Dict[str, int]`` the ``ast.Tuple`` holding ``str`` and ``int``.

    Args:
        arg: the annotation expression; must be an ``ast.Subscript``
            (callers are expected to check this before calling).

    Returns:
        The expression found inside the square brackets of ``arg``.

    Raises:
        AssertionError: if ``arg`` is not an ``ast.Subscript``.
    """
    assert isinstance(arg, ast.Subscript)  # e.g. arg = C[T]

    type_param = arg.slice

    # python<3.9 wraps a simple subscript in ``ast.Index``; from python-3.9 on
    # the slice is the expression itself. ``ast.Index`` is deprecated and
    # documented as slated for removal, so look it up defensively instead of
    # assuming the attribute exists.
    index_cls = getattr(ast, "Index", None)
    if index_cls is not None and isinstance(type_param, index_cls):
        return type_param.value  # python<3.9
    return type_param  # python>=3.9


def get_optional_type(arg: ast.expr) -> Optional[ast.expr]:
    """
    Returns the type parameter ``T`` of ``Optional[T]`` or ``None`` if ``arg``
    is not an ``Optional``. Handles both:

    1. ``Optional[T]`` / ``typing.Optional[T]`` (any module-qualified spelling
       whose final attribute is ``Optional``)
    2. ``T | None`` or ``None | T`` (python>=3.10 - PEP 604)

    Args:
        arg: the type annotation expression to inspect.

    Returns:
        The expression for ``T`` when ``arg`` is an optional type,
        otherwise ``None``.
    """
    # case 1: 'a: Optional[T]' or 'a: typing.Optional[T]'
    if isinstance(arg, ast.Subscript):
        base = arg.value
        # The subscripted object is an ast.Name for `Optional[...]` but an
        # ast.Attribute for `typing.Optional[...]`; anything else (which has
        # neither `.id` nor a meaningful name) is simply not an Optional.
        if isinstance(base, ast.Name):
            base_name = base.id
        elif isinstance(base, ast.Attribute):
            base_name = base.attr
        else:
            base_name = None

        if base_name == "Optional":
            return get_generic_type(arg)

    # case 2: 'a: T | None' or 'a: None | T'
    if sys.version_info >= (3, 10):  # PEP 604 introduced in python-3.10
        if isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.BitOr):
            if isinstance(arg.right, ast.Constant) and arg.right.value is None:
                return arg.left
            if isinstance(arg.left, ast.Constant) and arg.left.value is None:
                return arg.right

    # case 3: is not optional
    return None


class ArgTypeValidator(ComponentFunctionValidator):
"""Validates component function's argument types."""

def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
linter_errors = []
for arg_def in app_specs_func_def.args.args:
Expand All @@ -133,53 +182,68 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
return linter_errors

def _validate_arg_def(
self, function_name: str, arg_def: ast.arg
self, function_name: str, arg: ast.arg
) -> List[LinterMessage]:
if not arg_def.annotation:
return [
self._gen_linter_message(
f"Arg {arg_def.arg} missing type annotation", arg_def.lineno
)
]
if isinstance(arg_def.annotation, ast.Name):
arg_type = arg.annotation # type hint

def ok() -> list[LinterMessage]:
# return value when validation passes (e.g. no linter errors)
return []
complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
if complex_type_def.value.id == "Optional":
# ast module in python3.9 does not have ast.Index wrapper
if isinstance(complex_type_def.slice, ast.Index):
complex_type_def = complex_type_def.slice.value
else:
complex_type_def = complex_type_def.slice
# Check if type is Optional[primitive_type]
if isinstance(complex_type_def, ast.Name):
return []
# Check if type is Union[Dict,List]
type_name = complex_type_def.value.id
if type_name != "Dict" and type_name != "List":
desc = (
f"`{function_name}` allows only Dict, List as complex types."
f"Argument `{arg_def.arg}` has: {type_name}"
)
return [self._gen_linter_message(desc, arg_def.lineno)]
linter_errors = []
# ast module in python3.9 does not have objects wrapped in ast.Index
if isinstance(complex_type_def.slice, ast.Index):
sub_type = complex_type_def.slice.value

def err(reason: str) -> list[LinterMessage]:
msg = f"{reason} for argument {ast.unparse(arg)!r} in function {function_name!r}"
return [self._gen_linter_message(msg, arg.lineno)]

if not arg_type:
return err("Missing type annotation")

# Case 1: optional
if T := get_optional_type(arg_type):
# NOTE: optional types can be primitives or any of the allowed container types
# so check if arg is an optional, and if so, run the rest of the validation with the unpacked type
arg_type = T

# Case 2: int, float, str, bool
if is_primitive(arg_type):
return ok()
# Case 3: Containers (Dict, List, Tuple)
elif isinstance(arg_type, ast.Subscript):
container_type = arg_type.value.id

if container_type in ["Dict", "dict"]:
KV = get_generic_type(arg_type)

assert isinstance(KV, ast.Tuple) # dict[K,V] has ast.Tuple slice

K, V = KV.elts
if not is_primitive(K):
return err(f"Non-primitive key type {ast.unparse(K)!r}")
if not is_primitive(V):
return err(f"Non-primitive value type {ast.unparse(V)!r}")
return ok()
elif container_type in ["List", "list"]:
T = get_generic_type(arg_type)
if is_primitive(T):
return ok()
else:
return err(f"Non-primitive element type {ast.unparse(T)!r}")
elif container_type in ["Tuple", "tuple"]:
E_N = get_generic_type(arg_type)
assert isinstance(E_N, ast.Tuple) # tuple[...] has ast.Tuple slice

for e in E_N.elts:
if not is_primitive(e):
return err(f"Non-primitive element type '{ast.unparse(e)!r}'")

return ok()

return err(f"Unsupported container type {container_type!r}")
else:
sub_type = complex_type_def.slice
if type_name == "Dict":
sub_type_tuple = cast(ast.Tuple, sub_type)
for el in sub_type_tuple.elts:
if not isinstance(el, ast.Name):
desc = "Dict can only have primitive types"
linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
elif not isinstance(sub_type, ast.Name):
desc = "List can only have primitive types"
linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
return linter_errors
return err(f"Unsupported argument type {ast.unparse(arg_type)!r}")


class TorchxReturnValidator(TorchxFunctionValidator):
class ReturnTypeValidator(ComponentFunctionValidator):
"""Validates that component functions always return AppDef type"""

def __init__(self, supported_return_type: str) -> None:
super().__init__()
Expand Down Expand Up @@ -231,7 +295,7 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
return linter_errors


class TorchFunctionVisitor(ast.NodeVisitor):
class ComponentFnVisitor(ast.NodeVisitor):
"""
Visitor that finds the component_function and runs registered validators on it.
Current registered validators:
Expand All @@ -252,12 +316,12 @@ class TorchFunctionVisitor(ast.NodeVisitor):
def __init__(
self,
component_function_name: str,
validators: Optional[List[TorchxFunctionValidator]],
validators: Optional[List[ComponentFunctionValidator]],
) -> None:
if validators is None:
self.validators: List[TorchxFunctionValidator] = [
TorchxFunctionArgsValidator(),
TorchxReturnValidator("AppDef"),
self.validators: List[ComponentFunctionValidator] = [
ArgTypeValidator(),
ReturnTypeValidator("AppDef"),
]
else:
self.validators = validators
Expand All @@ -279,7 +343,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
def validate(
path: str,
component_function: str,
validators: Optional[List[TorchxFunctionValidator]],
validators: Optional[List[ComponentFunctionValidator]] = None,
) -> List[LinterMessage]:
"""
Validates the function to make sure it complies the component standard.
Expand Down Expand Up @@ -309,7 +373,7 @@ def validate(
severity="error",
)
return [linter_message]
visitor = TorchFunctionVisitor(component_function, validators)
visitor = ComponentFnVisitor(component_function, validators)
visitor.visit(module)
linter_errors = visitor.linter_errors
if not visitor.visited_function:
Expand Down
30 changes: 17 additions & 13 deletions torchx/specs/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from types import ModuleType
from typing import Any, Callable, Dict, Generator, List, Optional, Union

from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
from torchx.specs.file_linter import (
ComponentFunctionValidator,
get_fn_docstring,
validate,
)
from torchx.util import entrypoints
from torchx.util.io import read_conf_file
from torchx.util.types import none_throws
Expand Down Expand Up @@ -64,7 +68,7 @@ class _Component:
class ComponentsFinder(abc.ABC):
@abc.abstractmethod
def find(
self, validators: Optional[List[TorchxFunctionValidator]]
self, validators: Optional[List[ComponentFunctionValidator]]
) -> List[_Component]:
"""
Retrieves a set of components. A component is defined as a python
Expand Down Expand Up @@ -210,7 +214,7 @@ def _iter_modules_recursive(
yield self._try_import(module_info.name)

def find(
self, validators: Optional[List[TorchxFunctionValidator]]
self, validators: Optional[List[ComponentFunctionValidator]]
) -> List[_Component]:
components = []
for m in self._iter_modules_recursive(self.base_module):
Expand All @@ -230,7 +234,7 @@ def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
return module

def _get_components_from_module(
self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]]
self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
) -> List[_Component]:
functions = getmembers(module, isfunction)
component_defs = []
Expand Down Expand Up @@ -269,7 +273,7 @@ def _get_validation_errors(
self,
path: str,
function_name: str,
validators: Optional[List[TorchxFunctionValidator]],
validators: Optional[List[ComponentFunctionValidator]],
) -> List[str]:
linter_errors = validate(path, function_name, validators)
return [linter_error.description for linter_error in linter_errors]
Expand All @@ -289,7 +293,7 @@ def _get_path_to_function_decl(
return path_to_function_decl

def find(
self, validators: Optional[List[TorchxFunctionValidator]]
self, validators: Optional[List[ComponentFunctionValidator]]
) -> List[_Component]:

file_source = read_conf_file(self._filepath)
Expand Down Expand Up @@ -321,7 +325,7 @@ def find(


def _load_custom_components(
validators: Optional[List[TorchxFunctionValidator]],
validators: Optional[List[ComponentFunctionValidator]],
) -> List[_Component]:
component_modules = {
name: load_fn()
Expand All @@ -346,7 +350,7 @@ def _load_custom_components(


def _load_components(
validators: Optional[List[TorchxFunctionValidator]],
validators: Optional[List[ComponentFunctionValidator]],
) -> Dict[str, _Component]:
"""
Loads either the custom component defs from the entrypoint ``[torchx.components]``
Expand All @@ -368,7 +372,7 @@ def _load_components(


def _find_components(
validators: Optional[List[TorchxFunctionValidator]],
validators: Optional[List[ComponentFunctionValidator]],
) -> Dict[str, _Component]:
global _components
if not _components:
Expand All @@ -381,7 +385,7 @@ def _is_custom_component(component_name: str) -> bool:


def _find_custom_components(
name: str, validators: Optional[List[TorchxFunctionValidator]]
name: str, validators: Optional[List[ComponentFunctionValidator]]
) -> Dict[str, _Component]:
if ":" not in name:
raise ValueError(
Expand All @@ -393,7 +397,7 @@ def _find_custom_components(


def get_components(
validators: Optional[List[TorchxFunctionValidator]] = None,
validators: Optional[List[ComponentFunctionValidator]] = None,
) -> Dict[str, _Component]:
"""
Returns all custom components registered via ``[torchx.components]`` entrypoints
Expand Down Expand Up @@ -448,7 +452,7 @@ def get_components(


def get_component(
name: str, validators: Optional[List[TorchxFunctionValidator]] = None
name: str, validators: Optional[List[ComponentFunctionValidator]] = None
) -> _Component:
"""
Retrieves components by the provided name.
Expand Down Expand Up @@ -477,7 +481,7 @@ def get_component(


def get_builtin_source(
name: str, validators: Optional[List[TorchxFunctionValidator]] = None
name: str, validators: Optional[List[ComponentFunctionValidator]] = None
) -> str:
"""
Returns a string of the the builtin component's function source code
Expand Down
Loading
Loading