77from __future__ import annotations
88
99import re
10- from typing import Callable , Final , Iterator , Literal , Optional
10+ import sys
11+ import warnings
12+ from dataclasses import dataclass
13+ from typing import Any , Callable , Dict , Final , Iterator , Literal , Optional
1114
1215import luigi
1316from mypy .expandtype import expand_type
1417from mypy .nodes import (
18+ ARG_NAMED ,
1519 ARG_NAMED_OPT ,
20+ ArgKind ,
1621 Argument ,
1722 AssignmentStmt ,
1823 Block ,
3136 TypeInfo ,
3237 Var ,
3338)
39+ from mypy .options import Options
3440from mypy .plugin import ClassDefContext , FunctionContext , Plugin , SemanticAnalyzerPluginInterface
3541from mypy .plugins .common import (
3642 add_method_to_class ,
5561PARAMETER_TMP_MATCHER : Final = re .compile (r'^\w*Parameter$' )
5662
5763
64+ @dataclass
65+ class TaskOnKartPluginOptions :
66+ # Whether to error on missing parameters in the constructor.
67+ # Some projects use luigi.Config to set parameters, which does not require parameters to be explicitly passed to the constructor.
68+ error_on_missing_parameters : bool = False
69+
70+ @classmethod
71+ def _parse_toml (cls , config_file : str ) -> Dict [str , Any ]:
72+ if sys .version_info >= (3 , 11 ):
73+ import tomllib as toml_
74+ else :
75+ try :
76+ import tomli as toml_
77+ except ImportError : # pragma: no cover
78+ warnings .warn ('install tomli to parse pyproject.toml under Python 3.10' , stacklevel = 1 )
79+ return {}
80+
81+ with open (config_file , 'rb' ) as f :
82+ return toml_ .load (f )
83+
84+ @classmethod
85+ def parse_config_file (cls , config_file : Optional [str ]) -> 'TaskOnKartPluginOptions' :
86+ if not config_file :
87+ return cls ()
88+
89+ # TODO: support other configuration file formats if necessary.
90+ if not config_file .endswith ('.toml' ):
91+ warnings .warn ('gokart mypy plugin can be configured by pyproject.toml' , stacklevel = 1 )
92+ return cls ()
93+
94+ config = cls ._parse_toml (config_file )
95+ gokart_plugin_config = config .get ('tool' , {}).get ('gokart-mypy' , {})
96+
97+ error_on_missing_parameters = gokart_plugin_config .get ('error_on_missing_parameters' , False )
98+ if not isinstance (error_on_missing_parameters , bool ):
99+ raise ValueError ('error_on_missing_parameters must be a boolean' )
100+ return cls (error_on_missing_parameters = error_on_missing_parameters )
101+
102+
58103class TaskOnKartPlugin (Plugin ):
104+ def __init__ (self , options : Options ) -> None :
105+ super ().__init__ (options )
106+ self ._options = TaskOnKartPluginOptions .parse_config_file (options .config_file )
107+
59108 def get_base_class_hook (self , fullname : str ) -> Callable [[ClassDefContext ], None ] | None :
60109 # The following gathers attributes from gokart.TaskOnKart such as `workspace_directory`
61110 # the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`.
@@ -77,7 +126,7 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
77126 return None
78127
79128 def _task_on_kart_class_maker_callback (self , ctx : ClassDefContext ) -> None :
80- transformer = TaskOnKartTransformer (ctx .cls , ctx .reason , ctx .api )
129+ transformer = TaskOnKartTransformer (ctx .cls , ctx .reason , ctx .api , self . _options )
81130 transformer .transform ()
82131
83132 def _task_on_kart_parameter_field_callback (self , ctx : FunctionContext ) -> Type :
@@ -124,6 +173,7 @@ def __init__(
124173 type : Type | None ,
125174 info : TypeInfo ,
126175 api : SemanticAnalyzerPluginInterface ,
176+ options : TaskOnKartPluginOptions ,
127177 ) -> None :
128178 self .name = name
129179 self .has_default = has_default
@@ -132,12 +182,12 @@ def __init__(
132182 self .type = type # Type as __init__ argument
133183 self .info = info
134184 self ._api = api
185+ self ._options = options
135186
136187 def to_argument (self , current_info : TypeInfo , * , of : Literal ['__init__' ,]) -> Argument :
137188 if of == '__init__' :
138- # All arguments to __init__ are keyword-only and optional
139- # This is because gokart can set parameters by configuration'
140- arg_kind = ARG_NAMED_OPT
189+ arg_kind = self ._get_arg_kind_by_options ()
190+
141191 return Argument (
142192 variable = self .to_var (current_info ),
143193 type_annotation = self .expand_type (current_info ),
@@ -169,10 +219,10 @@ def serialize(self) -> JsonDict:
169219 }
170220
171221 @classmethod
172- def deserialize (cls , info : TypeInfo , data : JsonDict , api : SemanticAnalyzerPluginInterface ) -> TaskOnKartAttribute :
222+ def deserialize (cls , info : TypeInfo , data : JsonDict , api : SemanticAnalyzerPluginInterface , options : TaskOnKartPluginOptions ) -> TaskOnKartAttribute :
173223 data = data .copy ()
174224 typ = deserialize_and_fixup_type (data .pop ('type' ), api )
175- return cls (type = typ , info = info , ** data , api = api )
225+ return cls (type = typ , info = info , ** data , api = api , options = options )
176226
177227 def expand_typevar_from_subtype (self , sub_type : TypeInfo ) -> None :
178228 """Expands type vars in the context of a subtype when an attribute is inherited
@@ -181,6 +231,22 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
181231 with state .strict_optional_set (self ._api .options .strict_optional ):
182232 self .type = map_type_from_supertype (self .type , sub_type , self .info )
183233
234+ def _get_arg_kind_by_options (self ) -> Literal [ArgKind .ARG_NAMED , ArgKind .ARG_NAMED_OPT ]:
235+ """Set the argument kind based on the options.
236+
237+ if `error_on_missing_parameters` is True, the argument kind is `ARG_NAMED` when the attribute has no default value.
238+ This means the that all the parameters are passed to the constructor as keyword-only arguments.
239+
240+ Returns:
241+ Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]: The argument kind.
242+ """
243+ if not self ._options .error_on_missing_parameters :
244+ return ARG_NAMED_OPT
245+ if self .has_default :
246+ return ARG_NAMED_OPT
247+ # required parameter
248+ return ARG_NAMED
249+
184250
185251class TaskOnKartTransformer :
186252 """Implement the behavior of gokart.TaskOnKart."""
@@ -190,10 +256,12 @@ def __init__(
190256 cls : ClassDef ,
191257 reason : Expression | Statement ,
192258 api : SemanticAnalyzerPluginInterface ,
259+ options : TaskOnKartPluginOptions ,
193260 ) -> None :
194261 self ._cls = cls
195262 self ._reason = reason
196263 self ._api = api
264+ self ._options = options
197265
198266 def transform (self ) -> bool :
199267 """Apply all the necessary transformations to the underlying gokart.TaskOnKart"""
@@ -266,7 +334,7 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
266334 for data in info .metadata [METADATA_TAG ]['attributes' ]:
267335 name : str = data ['name' ]
268336
269- attr = TaskOnKartAttribute .deserialize (info , data , self ._api )
337+ attr = TaskOnKartAttribute .deserialize (info , data , self ._api , self . _options )
270338 # TODO: We shouldn't be performing type operations during the main
271339 # semantic analysis pass, since some TypeInfo attributes might
272340 # still be in flux. This should be performed in a later phase.
@@ -336,6 +404,7 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
336404 type = init_type ,
337405 info = cls .info ,
338406 api = self ._api ,
407+ options = self ._options ,
339408 )
340409
341410 return list (found_attrs .values ())
0 commit comments