Skip to content

Commit 82f76cf

Browse files
authored
Better type annotations (#2981)
Use unqouted types in type annotations. Fix some incorrect annotations.
1 parent d1bdebd commit 82f76cf

File tree

1 file changed

+35
-25
lines changed

1 file changed

+35
-25
lines changed

python/sdist/amici/swig.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,26 @@ class TypeHintFixer(ast.NodeTransformer):
1818
"size_t": ast.Name("int"),
1919
"bool": ast.Name("bool"),
2020
"boolean": ast.Name("bool"),
21-
"std::unique_ptr< amici::Solver >": ast.Constant("Solver"),
22-
"amici::InternalSensitivityMethod": ast.Constant(
21+
"std::unique_ptr< amici::Solver >": ast.Name("Solver"),
22+
"amici::InternalSensitivityMethod": ast.Name(
2323
"InternalSensitivityMethod"
2424
),
25-
"amici::InterpolationType": ast.Constant("InterpolationType"),
26-
"amici::LinearMultistepMethod": ast.Constant("LinearMultistepMethod"),
27-
"amici::LinearSolver": ast.Constant("LinearSolver"),
28-
"amici::Model *": ast.Constant("Model"),
29-
"amici::Model const *": ast.Constant("Model"),
30-
"amici::NewtonDampingFactorMode": ast.Constant(
31-
"NewtonDampingFactorMode"
32-
),
33-
"amici::NonlinearSolverIteration": ast.Constant(
25+
"amici::InterpolationType": ast.Name("InterpolationType"),
26+
"amici::LinearMultistepMethod": ast.Name("LinearMultistepMethod"),
27+
"amici::LinearSolver": ast.Name("LinearSolver"),
28+
"amici::Model *": ast.Name("Model"),
29+
"amici::Model const *": ast.Name("Model"),
30+
"amici::NewtonDampingFactorMode": ast.Name("NewtonDampingFactorMode"),
31+
"amici::NonlinearSolverIteration": ast.Name(
3432
"NonlinearSolverIteration"
3533
),
36-
"amici::ObservableScaling": ast.Constant("ObservableScaling"),
37-
"amici::ParameterScaling": ast.Constant("ParameterScaling"),
38-
"amici::RDataReporting": ast.Constant("RDataReporting"),
39-
"amici::SensitivityMethod": ast.Constant("SensitivityMethod"),
40-
"amici::SensitivityOrder": ast.Constant("SensitivityOrder"),
41-
"amici::Solver *": ast.Constant("Solver"),
42-
"amici::SteadyStateSensitivityMode": ast.Constant(
34+
"amici::ObservableScaling": ast.Name("ObservableScaling"),
35+
"amici::ParameterScaling": ast.Name("ParameterScaling"),
36+
"amici::RDataReporting": ast.Name("RDataReporting"),
37+
"amici::SensitivityMethod": ast.Name("SensitivityMethod"),
38+
"amici::SensitivityOrder": ast.Name("SensitivityOrder"),
39+
"amici::Solver *": ast.Name("Solver"),
40+
"amici::SteadyStateSensitivityMode": ast.Name(
4341
"SteadyStateSensitivityMode"
4442
),
4543
"amici::realtype": ast.Name("float"),
@@ -49,15 +47,23 @@ class TypeHintFixer(ast.NodeTransformer):
4947
"StringVector": ast.Name("Sequence[str]"),
5048
"std::string": ast.Name("str"),
5149
"std::string const &": ast.Name("str"),
52-
"std::unique_ptr< amici::ExpData >": ast.Constant("ExpData"),
53-
"std::unique_ptr< amici::ReturnData >": ast.Constant("ReturnData"),
50+
"std::unique_ptr< amici::ExpData >": ast.Name("ExpData"),
51+
"std::unique_ptr< amici::ReturnData >": ast.Name("ReturnData"),
5452
"std::vector< amici::ParameterScaling,"
55-
"std::allocator< amici::ParameterScaling > > const &": ast.Constant(
53+
"std::allocator< amici::ParameterScaling > > const &": ast.Name(
5654
"ParameterScalingVector"
5755
),
58-
"H5::H5File": None,
5956
}
6057

58+
def __init__(self):
59+
super().__init__()
60+
61+
# Add all mapped-to type names to the mapping dict to convert any
62+
# quoted occurrences of those types to unquoted types
63+
for annot in list(self.mapping.values()):
64+
if isinstance(annot, ast.Name):
65+
self.mapping[annot.id] = annot
66+
6167
def visit_FunctionDef(self, node):
6268
# convert type/rtype from docstring to annotation, if possible.
6369
# those may be c++ types, not valid in python, that need to be
@@ -140,7 +146,9 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
140146
for line_no, line in enumerate(docstring):
141147
if type_str := self.extract_rtype(line):
142148
# handle `:rtype:`
143-
node.returns = ast.Constant(type_str)
149+
node.returns = self.mapping.get(
150+
type_str, ast.Constant(type_str)
151+
)
144152
lines_to_remove.add(line_no)
145153
continue
146154

@@ -149,7 +157,9 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
149157
# handle `:type ...:`
150158
for arg in node.args.args:
151159
if arg.arg == arg_name:
152-
arg.annotation = ast.Constant(type_str)
160+
arg.annotation = self.mapping.get(
161+
type_str, ast.Name(type_str)
162+
)
153163
lines_to_remove.add(line_no)
154164

155165
if lines_to_remove:
@@ -160,7 +170,7 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
160170
for line_no, line in enumerate(docstring)
161171
if line_no not in lines_to_remove
162172
)
163-
node.body[0].value = ast.Str(new_docstring)
173+
node.body[0].value = ast.Constant(new_docstring)
164174

165175
@staticmethod
166176
def extract_type(line: str) -> tuple[str, str] | tuple[None, None]:

0 commit comments

Comments
 (0)