diff --git a/torchx/specs/file_linter.py b/torchx/specs/file_linter.py index 90b1df064..b11abd78f 100644 --- a/torchx/specs/file_linter.py +++ b/torchx/specs/file_linter.py @@ -180,6 +180,11 @@ def _validate_arg_def( class TorchxReturnValidator(TorchxFunctionValidator): + + def __init__(self, supported_return_type: str) -> None: + super().__init__() + self._supported_return_type = supported_return_type + def _get_return_annotation( self, app_specs_func_def: ast.FunctionDef ) -> Optional[str]: @@ -203,7 +208,7 @@ def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]: * AppDef * specs.AppDef """ - supported_return_annotation = "AppDef" + supported_return_annotation = self._supported_return_type return_annotation = self._get_return_annotation(app_specs_func_def) linter_errors = [] if not return_annotation: @@ -252,7 +257,7 @@ def __init__( if validators is None: self.validators: List[TorchxFunctionValidator] = [ TorchxFunctionArgsValidator(), - TorchxReturnValidator(), + TorchxReturnValidator("AppDef"), ] else: self.validators = validators