@@ -18,28 +18,26 @@ class TypeHintFixer(ast.NodeTransformer):
1818 "size_t" : ast .Name ("int" ),
1919 "bool" : ast .Name ("bool" ),
2020 "boolean" : ast .Name ("bool" ),
21- "std::unique_ptr< amici::Solver >" : ast .Constant ("Solver" ),
22- "amici::InternalSensitivityMethod" : ast .Constant (
21+ "std::unique_ptr< amici::Solver >" : ast .Name ("Solver" ),
22+ "amici::InternalSensitivityMethod" : ast .Name (
2323 "InternalSensitivityMethod"
2424 ),
25- "amici::InterpolationType" : ast .Constant ("InterpolationType" ),
26- "amici::LinearMultistepMethod" : ast .Constant ("LinearMultistepMethod" ),
27- "amici::LinearSolver" : ast .Constant ("LinearSolver" ),
28- "amici::Model *" : ast .Constant ("Model" ),
29- "amici::Model const *" : ast .Constant ("Model" ),
30- "amici::NewtonDampingFactorMode" : ast .Constant (
31- "NewtonDampingFactorMode"
32- ),
33- "amici::NonlinearSolverIteration" : ast .Constant (
25+ "amici::InterpolationType" : ast .Name ("InterpolationType" ),
26+ "amici::LinearMultistepMethod" : ast .Name ("LinearMultistepMethod" ),
27+ "amici::LinearSolver" : ast .Name ("LinearSolver" ),
28+ "amici::Model *" : ast .Name ("Model" ),
29+ "amici::Model const *" : ast .Name ("Model" ),
30+ "amici::NewtonDampingFactorMode" : ast .Name ("NewtonDampingFactorMode" ),
31+ "amici::NonlinearSolverIteration" : ast .Name (
3432 "NonlinearSolverIteration"
3533 ),
36- "amici::ObservableScaling" : ast .Constant ("ObservableScaling" ),
37- "amici::ParameterScaling" : ast .Constant ("ParameterScaling" ),
38- "amici::RDataReporting" : ast .Constant ("RDataReporting" ),
39- "amici::SensitivityMethod" : ast .Constant ("SensitivityMethod" ),
40- "amici::SensitivityOrder" : ast .Constant ("SensitivityOrder" ),
41- "amici::Solver *" : ast .Constant ("Solver" ),
42- "amici::SteadyStateSensitivityMode" : ast .Constant (
34+ "amici::ObservableScaling" : ast .Name ("ObservableScaling" ),
35+ "amici::ParameterScaling" : ast .Name ("ParameterScaling" ),
36+ "amici::RDataReporting" : ast .Name ("RDataReporting" ),
37+ "amici::SensitivityMethod" : ast .Name ("SensitivityMethod" ),
38+ "amici::SensitivityOrder" : ast .Name ("SensitivityOrder" ),
39+ "amici::Solver *" : ast .Name ("Solver" ),
40+ "amici::SteadyStateSensitivityMode" : ast .Name (
4341 "SteadyStateSensitivityMode"
4442 ),
4543 "amici::realtype" : ast .Name ("float" ),
@@ -49,15 +47,23 @@ class TypeHintFixer(ast.NodeTransformer):
4947 "StringVector" : ast .Name ("Sequence[str]" ),
5048 "std::string" : ast .Name ("str" ),
5149 "std::string const &" : ast .Name ("str" ),
52- "std::unique_ptr< amici::ExpData >" : ast .Constant ("ExpData" ),
53- "std::unique_ptr< amici::ReturnData >" : ast .Constant ("ReturnData" ),
50+ "std::unique_ptr< amici::ExpData >" : ast .Name ("ExpData" ),
51+ "std::unique_ptr< amici::ReturnData >" : ast .Name ("ReturnData" ),
5452 "std::vector< amici::ParameterScaling,"
55- "std::allocator< amici::ParameterScaling > > const &" : ast .Constant (
53+ "std::allocator< amici::ParameterScaling > > const &" : ast .Name (
5654 "ParameterScalingVector"
5755 ),
58- "H5::H5File" : None ,
5956 }
6057
58+ def __init__ (self ):
59+ super ().__init__ ()
60+
61+ # Add all mapped-to type names to the mapping dict to convert any
62+ # quoted occurrences of those types to unquoted types
63+ for annot in list (self .mapping .values ()):
64+ if isinstance (annot , ast .Name ):
65+ self .mapping [annot .id ] = annot
66+
6167 def visit_FunctionDef (self , node ):
6268 # convert type/rtype from docstring to annotation, if possible.
6369 # those may be c++ types, not valid in python, that need to be
@@ -140,7 +146,9 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
140146 for line_no , line in enumerate (docstring ):
141147 if type_str := self .extract_rtype (line ):
142148 # handle `:rtype:`
143- node .returns = ast .Constant (type_str )
149+ node .returns = self .mapping .get (
150+ type_str , ast .Constant (type_str )
151+ )
144152 lines_to_remove .add (line_no )
145153 continue
146154
@@ -149,7 +157,9 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
149157 # handle `:type ...:`
150158 for arg in node .args .args :
151159 if arg .arg == arg_name :
152- arg .annotation = ast .Constant (type_str )
160+ arg .annotation = self .mapping .get (
161+ type_str , ast .Name (type_str )
162+ )
153163 lines_to_remove .add (line_no )
154164
155165 if lines_to_remove :
@@ -160,7 +170,7 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
160170 for line_no , line in enumerate (docstring )
161171 if line_no not in lines_to_remove
162172 )
163- node .body [0 ].value = ast .Str (new_docstring )
173+ node .body [0 ].value = ast .Constant (new_docstring )
164174
165175 @staticmethod
166176 def extract_type (line : str ) -> tuple [str , str ] | tuple [None , None ]:
0 commit comments