Skip to content

Commit 8d49ec1

Browse files
committed
feat: add gokart mypy options to handle missing parameters
1 parent 2ab1a36 commit 8d49ec1

5 files changed

Lines changed: 141 additions & 10 deletions

File tree

gokart/mypy.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,17 @@
77
from __future__ import annotations
88

99
import re
10-
from typing import Callable, Final, Iterator, Literal, Optional
10+
import sys
11+
import warnings
12+
from dataclasses import dataclass
13+
from typing import Any, Callable, Dict, Final, Iterator, Literal, Optional
1114

1215
import luigi
1316
from mypy.expandtype import expand_type
1417
from mypy.nodes import (
18+
ARG_NAMED,
1519
ARG_NAMED_OPT,
20+
ArgKind,
1621
Argument,
1722
AssignmentStmt,
1823
Block,
@@ -31,6 +36,7 @@
3136
TypeInfo,
3237
Var,
3338
)
39+
from mypy.options import Options
3440
from mypy.plugin import ClassDefContext, FunctionContext, Plugin, SemanticAnalyzerPluginInterface
3541
from mypy.plugins.common import (
3642
add_method_to_class,
@@ -55,7 +61,50 @@
5561
PARAMETER_TMP_MATCHER: Final = re.compile(r'^\w*Parameter$')
5662

5763

64+
@dataclass
65+
class TaskOnKartPluginOptions:
66+
# Whether to error on missing parameters in the constructor.
67+
# Some projects use luigi.Config to set parameters, which does not require parameters to be explicitly passed to the constructor.
68+
error_on_missing_parameters: bool = False
69+
70+
@classmethod
71+
def _parse_toml(cls, config_file: str) -> Dict[str, Any]:
72+
if sys.version_info >= (3, 11):
73+
import tomllib as toml_
74+
else:
75+
try:
76+
import tomli as toml_
77+
except ImportError: # pragma: no cover
78+
warnings.warn('install tomli to parse pyproject.toml under Python 3.10', stacklevel=1)
79+
return {}
80+
81+
with open(config_file, 'rb') as f:
82+
return toml_.load(f)
83+
84+
@classmethod
85+
def parse_config_file(cls, config_file: Optional[str]) -> 'TaskOnKartPluginOptions':
86+
if not config_file:
87+
return cls()
88+
89+
# TODO: support other configuration file formats if necessary.
90+
if not config_file.endswith('.toml'):
91+
warnings.warn('gokart mypy plugin can be configured by pyproject.toml', stacklevel=1)
92+
return cls()
93+
94+
config = cls._parse_toml(config_file)
95+
gokart_plugin_config = config.get('tool', {}).get('gokart-mypy', {})
96+
97+
error_on_missing_parameters = gokart_plugin_config.get('error_on_missing_parameters', False)
98+
if not isinstance(error_on_missing_parameters, bool):
99+
raise ValueError('error_on_missing_parameters must be a boolean')
100+
return cls(error_on_missing_parameters=error_on_missing_parameters)
101+
102+
58103
class TaskOnKartPlugin(Plugin):
104+
def __init__(self, options: Options) -> None:
105+
super().__init__(options)
106+
self._options = TaskOnKartPluginOptions.parse_config_file(options.config_file)
107+
59108
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
60109
# The following gathers attributes from gokart.TaskOnKart such as `workspace_directory`
61110
# the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`.
@@ -77,7 +126,7 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
77126
return None
78127

79128
def _task_on_kart_class_maker_callback(self, ctx: ClassDefContext) -> None:
80-
transformer = TaskOnKartTransformer(ctx.cls, ctx.reason, ctx.api)
129+
transformer = TaskOnKartTransformer(ctx.cls, ctx.reason, ctx.api, self._options)
81130
transformer.transform()
82131

83132
def _task_on_kart_parameter_field_callback(self, ctx: FunctionContext) -> Type:
@@ -124,6 +173,7 @@ def __init__(
124173
type: Type | None,
125174
info: TypeInfo,
126175
api: SemanticAnalyzerPluginInterface,
176+
options: TaskOnKartPluginOptions,
127177
) -> None:
128178
self.name = name
129179
self.has_default = has_default
@@ -132,12 +182,12 @@ def __init__(
132182
self.type = type # Type as __init__ argument
133183
self.info = info
134184
self._api = api
185+
self._options = options
135186

136187
def to_argument(self, current_info: TypeInfo, *, of: Literal['__init__',]) -> Argument:
137188
if of == '__init__':
138-
# All arguments to __init__ are keyword-only and optional
139-
# This is because gokart can set parameters by configuration'
140-
arg_kind = ARG_NAMED_OPT
189+
arg_kind = self._get_arg_kind_by_options()
190+
141191
return Argument(
142192
variable=self.to_var(current_info),
143193
type_annotation=self.expand_type(current_info),
@@ -169,10 +219,10 @@ def serialize(self) -> JsonDict:
169219
}
170220

171221
@classmethod
172-
def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface) -> TaskOnKartAttribute:
222+
def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface, options: TaskOnKartPluginOptions) -> TaskOnKartAttribute:
173223
data = data.copy()
174224
typ = deserialize_and_fixup_type(data.pop('type'), api)
175-
return cls(type=typ, info=info, **data, api=api)
225+
return cls(type=typ, info=info, **data, api=api, options=options)
176226

177227
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
178228
"""Expands type vars in the context of a subtype when an attribute is inherited
@@ -181,6 +231,22 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
181231
with state.strict_optional_set(self._api.options.strict_optional):
182232
self.type = map_type_from_supertype(self.type, sub_type, self.info)
183233

234+
def _get_arg_kind_by_options(self) -> Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]:
235+
"""Set the argument kind based on the options.
236+
237+
if `error_on_missing_parameters` is True, the argument kind is `ARG_NAMED` when the attribute has no default value.
238+
This means the that all the parameters are passed to the constructor as keyword-only arguments.
239+
240+
Returns:
241+
Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]: The argument kind.
242+
"""
243+
if not self._options.error_on_missing_parameters:
244+
return ARG_NAMED_OPT
245+
if self.has_default:
246+
return ARG_NAMED_OPT
247+
# required parameter
248+
return ARG_NAMED
249+
184250

185251
class TaskOnKartTransformer:
186252
"""Implement the behavior of gokart.TaskOnKart."""
@@ -190,10 +256,12 @@ def __init__(
190256
cls: ClassDef,
191257
reason: Expression | Statement,
192258
api: SemanticAnalyzerPluginInterface,
259+
options: TaskOnKartPluginOptions,
193260
) -> None:
194261
self._cls = cls
195262
self._reason = reason
196263
self._api = api
264+
self._options = options
197265

198266
def transform(self) -> bool:
199267
"""Apply all the necessary transformations to the underlying gokart.TaskOnKart"""
@@ -266,7 +334,7 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
266334
for data in info.metadata[METADATA_TAG]['attributes']:
267335
name: str = data['name']
268336

269-
attr = TaskOnKartAttribute.deserialize(info, data, self._api)
337+
attr = TaskOnKartAttribute.deserialize(info, data, self._api, self._options)
270338
# TODO: We shouldn't be performing type operations during the main
271339
# semantic analysis pass, since some TypeInfo attributes might
272340
# still be in flux. This should be performed in a later phase.
@@ -336,6 +404,7 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
336404
type=init_type,
337405
info=cls.info,
338406
api=self._api,
407+
options=self._options,
339408
)
340409

341410
return list(found_attrs.values())

test/config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33

44
CONFIG_DIR: Final[Path] = Path(__file__).parent.resolve()
55
PYPROJECT_TOML: Final[Path] = CONFIG_DIR / 'pyproject.toml'
6+
PYPROJECT_TOML_SET_ERROR_ON_MISSING_PARAMETERS: Final[Path] = CONFIG_DIR / 'pyproject_error_on_missing_parameters.toml'
67
TEST_CONFIG_INI: Final[Path] = CONFIG_DIR / 'test_config.ini'

test/config/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[tool.mypy]
2-
plugins = ["gokart.mypy:plugin"]
2+
plugins = ["gokart.mypy"]
33

44
[[tool.mypy.overrides]]
55
ignore_missing_imports = true
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[tool.mypy]
2+
plugins = ["gokart.mypy"]
3+
4+
[[tool.mypy.overrides]]
5+
ignore_missing_imports = true
6+
module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"]
7+
8+
[tool.gokart-mypy]
9+
error_on_missing_parameters = true

test/test_mypy.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from mypy import api
55

6-
from test.config import PYPROJECT_TOML
6+
from test.config import PYPROJECT_TOML, PYPROJECT_TOML_SET_ERROR_ON_MISSING_PARAMETERS
77

88

99
class TestMyMypyPlugin(unittest.TestCase):
@@ -122,3 +122,55 @@ class MyTask(gokart.TaskOnKart):
122122
test_file.flush()
123123
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
124124
self.assertIn('Success: no issues found', result[0])
125+
126+
def test_no_issue_found_when_missing_parameter_when_default_option(self):
127+
"""
128+
If `error_on_missing_parameters` is False (or default), mypy doesn't show any error when missing parameters.
129+
"""
130+
test_code = """
131+
import luigi
132+
import gokart
133+
134+
class MyTask(gokart.TaskOnKart):
135+
foo = luigi.IntParameter()
136+
bar = luigi.Parameter(default="bar")
137+
138+
MyTask()
139+
"""
140+
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
141+
test_file.write(test_code.encode('utf-8'))
142+
test_file.flush()
143+
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
144+
self.assertIn('Success: no issues found', result[0])
145+
146+
def test_issue_found_when_missing_parameter_when_error_on_missing_parameters_set_true(self):
147+
"""
148+
If `error_on_missing_parameters` is True, mypy shows an error when missing parameters.
149+
"""
150+
test_code = """
151+
import luigi
152+
import gokart
153+
154+
class MyTask(gokart.TaskOnKart):
155+
# issue: foo is missing
156+
foo = luigi.IntParameter()
157+
# bar has default value, so it is not required to set it.
158+
bar = luigi.Parameter(default="bar")
159+
160+
MyTask()
161+
"""
162+
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
163+
test_file.write(test_code.encode('utf-8'))
164+
test_file.flush()
165+
result = api.run(
166+
[
167+
'--show-traceback',
168+
'--no-incremental',
169+
'--cache-dir=/dev/null',
170+
'--config-file',
171+
str(PYPROJECT_TOML_SET_ERROR_ON_MISSING_PARAMETERS),
172+
test_file.name,
173+
]
174+
)
175+
self.assertIn('error: Missing named argument "foo" for "MyTask" [call-arg]', result[0])
176+
self.assertIn('Found 1 error in 1 file (checked 1 source file)', result[0])

0 commit comments

Comments
 (0)