Skip to content
This repository was archived by the owner on Jan 13, 2026. It is now read-only.

Commit b80a1ed

Browse files
committed
Support types with parens around them
While the parens would be redundant, this is actually allowed in Python. Merge `__convert_type_hint` into `__convert_type`.
1 parent 0f8ded0 commit b80a1ed

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

rewrite/rewrite/python/_parser_visitor.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def map_arg(self, node, default=None, vararg=False, kwarg=False):
138138
vararg_prefix = self.__source_before('*') if vararg else None
139139
name = self.__convert_name(node.arg, self.__map_type(node))
140140
after_name = self.__source_before(':') if node.annotation else Space.EMPTY
141-
type_expression = self.__convert_type_hint(node.annotation) if node.annotation else None
141+
type_expression = self.__convert_type(node.annotation) if node.annotation else None
142142
initializer = self.__pad_left(self.__source_before('='), self.__convert(default)) if default else None
143143

144144
return j.VariableDeclarations(
@@ -297,7 +297,7 @@ def visit_AnnAssign(self, node):
297297
random_id(),
298298
self.__source_before(':'),
299299
Markers.EMPTY,
300-
self.__convert_type_hint(node.annotation),
300+
self.__convert_type(node.annotation),
301301
self.__map_type(node.annotation)
302302
),
303303
self.__map_type(node)
@@ -312,7 +312,7 @@ def visit_AnnAssign(self, node):
312312
name = cast(j.Identifier, self.__convert(node.target))
313313
if node.annotation:
314314
after = self.__source_before(':')
315-
type = self.__convert_type_hint(node.annotation)
315+
type = self.__convert_type(node.annotation)
316316
else:
317317
after = Space.EMPTY
318318
type = None
@@ -446,7 +446,8 @@ def visit_With(self, node):
446446
parenthesized = self.__cursor_at('(')
447447
parens_handler = self.__push_parentheses(node, items_prefix, self._cursor) if parenthesized else None
448448

449-
resources = [self.__pad_list_element(self.__convert(r), i == len(node.items) - 1) for i, r in enumerate(node.items)]
449+
resources = [self.__pad_list_element(self.__convert(r), i == len(node.items) - 1) for i, r in
450+
enumerate(node.items)]
450451

451452
if parenthesized and self._parentheses_stack and self._parentheses_stack[-1] is parens_handler:
452453
self._cursor += 1
@@ -778,7 +779,8 @@ def visit_Store(self, node):
778779

779780
def visit_ExceptHandler(self, node):
780781
prefix = self.__source_before('except')
781-
except_type = self.__convert_type(node.type) if node.type else j.Empty(random_id(), Space.EMPTY, Markers.EMPTY)
782+
except_type = self.__convert_type(node.type) if node.type else j.Empty(random_id(), Space.EMPTY,
783+
Markers.EMPTY)
782784
if node.name:
783785
before_as = self.__source_before('as')
784786
except_type_name = self.__convert_name(node.name)
@@ -1174,7 +1176,8 @@ def visit_Call(self, node):
11741176
Markers.EMPTY,
11751177
select if isinstance(name, j.Identifier) else self.__pad_right(name, Space.EMPTY),
11761178
None,
1177-
name if isinstance(name, j.Identifier) else j.Identifier(random_id(), Space.EMPTY, Markers.EMPTY, [], "", None, None),
1179+
name if isinstance(name, j.Identifier) else j.Identifier(random_id(), Space.EMPTY, Markers.EMPTY, [], "",
1180+
None, None),
11781181
args,
11791182
self.__map_type(node)
11801183
)
@@ -1784,7 +1787,7 @@ def visit_UnaryOp(self, node):
17841787
self.__map_type(node)
17851788
)
17861789

1787-
def __convert_type_hint(self, node) -> Optional[TypeTree]:
1790+
def __convert_type(self, node) -> Optional[TypeTree]:
17881791
if isinstance(node, ast.Constant):
17891792
if node.value is None or node.value is Ellipsis:
17901793
return py.LiteralType(
@@ -1825,7 +1828,7 @@ def __convert_type_hint(self, node) -> Optional[TypeTree]:
18251828
self.__convert(node.value),
18261829
JContainer(
18271830
self.__source_before('['),
1828-
[self.__pad_list_element(self.__convert_type_hint(s), last=i == len(slices) - 1, end_delim=']') for
1831+
[self.__pad_list_element(self.__convert_type(s), last=i == len(slices) - 1, end_delim=']') for
18291832
i, s in
18301833
enumerate(slices)],
18311834
Markers.EMPTY
@@ -1837,9 +1840,9 @@ def __convert_type_hint(self, node) -> Optional[TypeTree]:
18371840
# NOTE: Type unions using `|` was added in Python 3.10
18381841
prefix = self.__whitespace()
18391842
# FIXME consider flattening nested unions
1840-
left = self.__pad_right(self.__convert_internal(node.left, self.__convert_type_hint),
1843+
left = self.__pad_right(self.__convert_internal(node.left, self.__convert_type),
18411844
self.__source_before('|'))
1842-
right = self.__pad_right(self.__convert_internal(node.right, self.__convert_type_hint), Space.EMPTY)
1845+
right = self.__pad_right(self.__convert_internal(node.right, self.__convert_type), Space.EMPTY)
18431846
return py.UnionType(
18441847
random_id(),
18451848
prefix,
@@ -1848,12 +1851,6 @@ def __convert_type_hint(self, node) -> Optional[TypeTree]:
18481851
self.__map_type(node)
18491852
)
18501853

1851-
return self.__convert_internal(node, self.__convert_type_hint)
1852-
1853-
def __convert(self, node) -> Optional[J]:
1854-
return self.__convert_internal(node, self.__convert)
1855-
1856-
def __convert_type(self, node) -> Optional[j.TypeTree]:
18571854
prefix = self.__whitespace()
18581855
converted_type = self.__convert_internal(node, self.__convert_type)
18591856
if is_of_type(converted_type, TypeTree):
@@ -1874,6 +1871,9 @@ def __convert_type(self, node) -> Optional[j.TypeTree]:
18741871
converted_type
18751872
)
18761873

1874+
def __convert(self, node) -> Optional[J]:
1875+
return self.__convert_internal(node, self.__convert)
1876+
18771877
def __convert_internal(self, node, recursion) -> Optional[J]:
18781878
if not node or not isinstance(node, ast.expr) or isinstance(node, ast.GeneratorExp):
18791879
return self.visit(cast(ast.AST, node)) if node else None

rewrite/tests/python/all/method_declaration_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,18 @@ def foo(**kwargs) :
125125
)
126126

127127

128+
def test_param_type_with_parens():
129+
# language=python
130+
rewrite_run(
131+
python(
132+
"""\
133+
def foo(i: (int)) :
134+
pass
135+
"""
136+
)
137+
)
138+
139+
128140
def test_one_line():
129141
# language=python
130142
rewrite_run(python("def f(x): x = x + 1; return x"))

0 commit comments

Comments
 (0)