77from __future__ import annotations
88
99import re
10+ import sys
11+ import warnings
1012from collections .abc import Iterator
11- from typing import Callable , Final , Literal
13+ from dataclasses import dataclass
14+ from enum import StrEnum
15+ from typing import Any , Callable , Final , Literal
1216
1317import luigi
1418from mypy .expandtype import expand_type
1519from mypy .nodes import (
20+ ARG_NAMED ,
1621 ARG_NAMED_OPT ,
22+ ArgKind ,
1723 Argument ,
1824 AssignmentStmt ,
1925 Block ,
3238 TypeInfo ,
3339 Var ,
3440)
41+ from mypy .options import Options
3542from mypy .plugin import ClassDefContext , FunctionContext , Plugin , SemanticAnalyzerPluginInterface
3643from mypy .plugins .common import (
3744 add_method_to_class ,
5663PARAMETER_TMP_MATCHER : Final = re .compile (r'^\w*Parameter$' )
5764
5865
66+ class PluginOptions (StrEnum ):
67+ DISALLOW_MISSING_PARAMETERS = 'disallow_missing_parameters'
68+
69+
70+ @dataclass
71+ class TaskOnKartPluginOptions :
72+ # Whether to error on missing parameters in the constructor.
73+ # Some projects use luigi.Config to set parameters, which does not require parameters to be explicitly passed to the constructor.
74+ disallow_missing_parameters : bool = False
75+
76+ @classmethod
77+ def _parse_toml (cls , config_file : str ) -> dict [str , Any ]:
78+ if sys .version_info >= (3 , 11 ):
79+ import tomllib as toml_
80+ else :
81+ try :
82+ import tomli as toml_
83+ except ImportError : # pragma: no cover
84+ warnings .warn ('install tomli to parse pyproject.toml under Python 3.10' , stacklevel = 1 )
85+ return {}
86+
87+ with open (config_file , 'rb' ) as f :
88+ return toml_ .load (f )
89+
90+ @classmethod
91+ def parse_config_file (cls , config_file : str ) -> TaskOnKartPluginOptions :
92+ # TODO: support other configuration file formats if necessary.
93+ if not config_file .endswith ('.toml' ):
94+ warnings .warn ('gokart mypy plugin can be configured by pyproject.toml' , stacklevel = 1 )
95+ return cls ()
96+
97+ config = cls ._parse_toml (config_file )
98+ gokart_plugin_config = config .get ('tool' , {}).get ('gokart-mypy' , {})
99+
100+ disallow_missing_parameters = gokart_plugin_config .get (PluginOptions .DISALLOW_MISSING_PARAMETERS .value , False )
101+ if not isinstance (disallow_missing_parameters , bool ):
102+ raise ValueError (f'{ PluginOptions .DISALLOW_MISSING_PARAMETERS .value } must be a boolean value' )
103+ return cls (disallow_missing_parameters = disallow_missing_parameters )
104+
105+
59106class TaskOnKartPlugin (Plugin ):
107+ def __init__ (self , options : Options ) -> None :
108+ super ().__init__ (options )
109+ if options .config_file is not None :
110+ self ._options = TaskOnKartPluginOptions .parse_config_file (options .config_file )
111+ else :
112+ self ._options = TaskOnKartPluginOptions ()
113+
60114 def get_base_class_hook (self , fullname : str ) -> Callable [[ClassDefContext ], None ] | None :
61115 # The following gathers attributes from gokart.TaskOnKart such as `workspace_directory`
62116 # the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`.
@@ -78,7 +132,7 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
78132 return None
79133
80134 def _task_on_kart_class_maker_callback (self , ctx : ClassDefContext ) -> None :
81- transformer = TaskOnKartTransformer (ctx .cls , ctx .reason , ctx .api )
135+ transformer = TaskOnKartTransformer (ctx .cls , ctx .reason , ctx .api , self . _options )
82136 transformer .transform ()
83137
84138 def _task_on_kart_parameter_field_callback (self , ctx : FunctionContext ) -> Type :
@@ -125,6 +179,7 @@ def __init__(
125179 type : Type | None ,
126180 info : TypeInfo ,
127181 api : SemanticAnalyzerPluginInterface ,
182+ options : TaskOnKartPluginOptions ,
128183 ) -> None :
129184 self .name = name
130185 self .has_default = has_default
@@ -133,12 +188,12 @@ def __init__(
133188 self .type = type # Type as __init__ argument
134189 self .info = info
135190 self ._api = api
191+ self ._options = options
136192
137193 def to_argument (self , current_info : TypeInfo , * , of : Literal ['__init__' ,]) -> Argument :
138194 if of == '__init__' :
139- # All arguments to __init__ are keyword-only and optional
140- # This is because gokart can set parameters by configuration'
141- arg_kind = ARG_NAMED_OPT
195+ arg_kind = self ._get_arg_kind_by_options ()
196+
142197 return Argument (
143198 variable = self .to_var (current_info ),
144199 type_annotation = self .expand_type (current_info ),
@@ -170,10 +225,10 @@ def serialize(self) -> JsonDict:
170225 }
171226
172227 @classmethod
173- def deserialize (cls , info : TypeInfo , data : JsonDict , api : SemanticAnalyzerPluginInterface ) -> TaskOnKartAttribute :
228+ def deserialize (cls , info : TypeInfo , data : JsonDict , api : SemanticAnalyzerPluginInterface , options : TaskOnKartPluginOptions ) -> TaskOnKartAttribute :
174229 data = data .copy ()
175230 typ = deserialize_and_fixup_type (data .pop ('type' ), api )
176- return cls (type = typ , info = info , ** data , api = api )
231+ return cls (type = typ , info = info , ** data , api = api , options = options )
177232
178233 def expand_typevar_from_subtype (self , sub_type : TypeInfo ) -> None :
179234 """Expands type vars in the context of a subtype when an attribute is inherited
@@ -182,6 +237,22 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
182237 with state .strict_optional_set (self ._api .options .strict_optional ):
183238 self .type = map_type_from_supertype (self .type , sub_type , self .info )
184239
240+ def _get_arg_kind_by_options (self ) -> Literal [ArgKind .ARG_NAMED , ArgKind .ARG_NAMED_OPT ]:
241+ """Set the argument kind based on the options.
242+
243+ if `disallow_missing_parameters` is True, the argument kind is `ARG_NAMED` when the attribute has no default value.
244+ This means the that all the parameters are passed to the constructor as keyword-only arguments.
245+
246+ Returns:
247+ Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]: The argument kind.
248+ """
249+ if not self ._options .disallow_missing_parameters :
250+ return ARG_NAMED_OPT
251+ if self .has_default :
252+ return ARG_NAMED_OPT
253+ # required parameter
254+ return ARG_NAMED
255+
185256
186257class TaskOnKartTransformer :
187258 """Implement the behavior of gokart.TaskOnKart."""
@@ -191,10 +262,12 @@ def __init__(
191262 cls : ClassDef ,
192263 reason : Expression | Statement ,
193264 api : SemanticAnalyzerPluginInterface ,
265+ options : TaskOnKartPluginOptions ,
194266 ) -> None :
195267 self ._cls = cls
196268 self ._reason = reason
197269 self ._api = api
270+ self ._options = options
198271
199272 def transform (self ) -> bool :
200273 """Apply all the necessary transformations to the underlying gokart.TaskOnKart"""
@@ -267,7 +340,7 @@ def collect_attributes(self) -> list[TaskOnKartAttribute] | None:
267340 for data in info .metadata [METADATA_TAG ]['attributes' ]:
268341 name : str = data ['name' ]
269342
270- attr = TaskOnKartAttribute .deserialize (info , data , self ._api )
343+ attr = TaskOnKartAttribute .deserialize (info , data , self ._api , self . _options )
271344 # TODO: We shouldn't be performing type operations during the main
272345 # semantic analysis pass, since some TypeInfo attributes might
273346 # still be in flux. This should be performed in a later phase.
@@ -337,6 +410,7 @@ def collect_attributes(self) -> list[TaskOnKartAttribute] | None:
337410 type = init_type ,
338411 info = cls .info ,
339412 api = self ._api ,
413+ options = self ._options ,
340414 )
341415
342416 return list (found_attrs .values ())
0 commit comments