Skip to content
This repository was archived by the owner on Jan 13, 2026. It is now read-only.

Commit cda5724

Browse files
committed
More consistent and robust __convert_type() function
1 parent b80a1ed commit cda5724

File tree

2 files changed

+21
-26
lines changed

2 files changed

+21
-26
lines changed

rewrite/rewrite/python/_parser_visitor.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,6 +1788,19 @@ def visit_UnaryOp(self, node):
17881788
)
17891789

17901790
def __convert_type(self, node) -> Optional[TypeTree]:
1791+
prefix = self.__whitespace()
1792+
converted_type = self.__convert_internal(node, self.__convert_type, self.__convert_type_mapper)
1793+
if is_of_type(converted_type, TypeTree):
1794+
return converted_type.with_prefix(prefix)
1795+
else:
1796+
return py.ExpressionTypeTree(
1797+
random_id(),
1798+
prefix,
1799+
Markers.EMPTY,
1800+
converted_type
1801+
)
1802+
1803+
def __convert_type_mapper(self, node) -> Optional[TypeTree]:
17911804
if isinstance(node, ast.Constant):
17921805
if node.value is None or node.value is Ellipsis:
17931806
return py.LiteralType(
@@ -1833,7 +1846,6 @@ def __convert_type(self, node) -> Optional[TypeTree]:
18331846
enumerate(slices)],
18341847
Markers.EMPTY
18351848
),
1836-
None,
18371849
None
18381850
)
18391851
elif isinstance(node, ast.BinOp):
@@ -1851,38 +1863,20 @@ def __convert_type(self, node) -> Optional[TypeTree]:
18511863
self.__map_type(node)
18521864
)
18531865

1854-
prefix = self.__whitespace()
1855-
converted_type = self.__convert_internal(node, self.__convert_type)
1856-
if is_of_type(converted_type, TypeTree):
1857-
return converted_type.with_prefix(prefix)
1858-
elif isinstance(converted_type, j.Literal):
1859-
return py.LiteralType(
1860-
random_id(),
1861-
prefix,
1862-
Markers.EMPTY,
1863-
converted_type,
1864-
self.__map_type(node)
1865-
)
1866-
else:
1867-
return py.ExpressionTypeTree(
1868-
random_id(),
1869-
prefix,
1870-
Markers.EMPTY,
1871-
converted_type
1872-
)
1866+
return self.__convert_internal(node, self.__convert_type)
18731867

18741868
def __convert(self, node) -> Optional[J]:
18751869
return self.__convert_internal(node, self.__convert)
18761870

1877-
def __convert_internal(self, node, recursion) -> Optional[J]:
1871+
def __convert_internal(self, node, recursion, mapping = None) -> Optional[J]:
18781872
if not node or not isinstance(node, ast.expr) or isinstance(node, ast.GeneratorExp):
18791873
return self.visit(cast(ast.AST, node)) if node else None
18801874

18811875
save_cursor = self._cursor
18821876
prefix = self.__whitespace()
18831877

18841878
# Handle normal expression or parenthesized expression
1885-
result = self._parse_expr(node, recursion, save_cursor, prefix)
1879+
result = self.__parse_expr(node, mapping or self.visit, recursion, save_cursor, prefix)
18861880

18871881
save_cursor_2 = self._cursor
18881882
suffix = self.__whitespace()
@@ -1903,17 +1897,17 @@ def __convert_internal(self, node, recursion) -> Optional[J]:
19031897
self._cursor = save_cursor_2
19041898
return result
19051899

1906-
def _parse_expr(self, node, recursion, save_cursor: int, prefix: str) -> J:
1900+
def __parse_expr(self, node, mapping, recursion, save_cursor: int, prefix: Space) -> J:
19071901
"""Parse either a normal expression or a parenthesized expression."""
19081902
if not (self._cursor < len(self._source) and self._source[self._cursor] == '('):
19091903
self._cursor = save_cursor
1910-
return self.visit(cast(ast.AST, node))
1904+
return mapping(cast(ast.AST, node))
19111905

19121906
self.__push_parentheses(node, prefix, save_cursor)
19131907

19141908
return recursion(node)
19151909

1916-
def __push_parentheses(self, node, prefix, save_cursor):
1910+
def __push_parentheses(self, node, prefix: Space, save_cursor):
19171911
self._cursor += 1
19181912
expr_prefix = self.__whitespace()
19191913
handler = (

rewrite/tests/python/all/method_declaration_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ def test_param_type_with_parens():
130130
rewrite_run(
131131
python(
132132
"""\
133-
def foo(i: (int)) :
133+
from typing import Tuple
134+
def foo(i: (Tuple[int])):
134135
pass
135136
"""
136137
)

0 commit comments

Comments
 (0)